diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 6fca800c90..42a9d5148a 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -115,6 +115,7 @@ PIPELINE_MAP = { "qr" : "ck_tile::BlockFmhaPipelineQRKSVS", "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", "qs" : "ck_tile::BlockFmhaPipelineQSKSVS", + "qr_async_trload" : "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload", } PIPELINE_ENUM_MAP = { @@ -123,6 +124,7 @@ PIPELINE_ENUM_MAP = { "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", "qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS", "qr_pagedkv" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qr_async_trload" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD", } BOOL_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 269af4e6a7..ce35c6a2a7 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -12,6 +12,7 @@ from typing import List, Optional, Tuple from codegen.cmake_config import * from codegen.cpp_symbol_map import * +from codegen.utils import update_file DTYPE_BITS = { @@ -83,6 +84,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< {F_mode}, fmha_variant_{F_idx}, fmha_mask_{F_idx}, + {F_trload}, fmha_trait_{F_idx}>; using fmha_pipeline_{F_idx} = {F_pipeline}< @@ -97,7 +99,7 @@ using fmha_kernel_{F_idx} = ck_tile::FmhaFwdKernel; using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; #include @@ -161,12 +163,19 @@ float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); }}; + + const bool has_load_tr = ck_tile::is_load_tr_supported(); {F_dispatch} return r; }} """ +FMHA_FWD_API_PER_TRLOAD=""" {F_if}({F_trload_cond}){{ +{F_dtype_case} + }} +""" + FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ {F_hdim_case} }} @@ -177,8 +186,8 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v < """ FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; + ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, 
{F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; return fmha_fwd_(s, a); }} """ @@ -221,6 +230,7 @@ class FmhaFwdApiTrait: dpad : str dvpad : str skip : str + tr_load : str constraint : CppConstraint @property @@ -231,13 +241,19 @@ class FmhaFwdApiTrait: @property def scheck(self) -> str: if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag == 'qr_async': + if self.pipeline_tag in ['qr_async', 'qr_async_trload']: if self.spad == 't' : return 'true' # always support else : return 'true' elif self.pipeline_tag in ['qr', 'qs']: if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.seqlen_q % {self.bm0} == 0' else: assert False + + @property + def seqtune(self) -> str: + if self.bm0 == 128: return 'true/*fall back to largest tile*/' # bm0 == 128 kernels are the catch-all tile: no upper bound on seqlen_q + else: + return f'a.seqlen_q <= {self.bm0}' @property def skcheck(self) -> str: @@ -248,6 +264,9 @@ class FmhaFwdApiTrait: elif self.pipeline_tag in ['qr', 'qs']: if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.seqlen_k % {self.bn0} == 0' + elif self.pipeline_tag == 'qr_async_trload': + if self.skpad == 't' : return 'true' + else: return 'true' else: assert False @property @@ -256,7 +275,7 @@ vec = int((32 * 4) / DTYPE_BITS[self.dtype]) if self.dpad == 't': return f'a.hdim_q % {vec} == 0' else : assert False - elif self.pipeline_tag in ['qr', 'qs']: + elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.hdim_q % {bk0submax} == 0' @@ -268,7 +287,7 @@ vec = int((32 * 4) / DTYPE_BITS[self.dtype]) if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' else : assert False - elif self.pipeline_tag in ['qr', 'qs']: + elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters!
(ugly) else : return f'a.hdim_v % {bk0submax} == 0' @@ -290,6 +309,7 @@ class FmhaFwdPipeline: F_squant : str # F_mask : str # value from MASK_MAP F_skip : str # true/false + F_trload : str # true/false F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) @property @@ -331,6 +351,9 @@ class FmhaFwdPipeline: if self.F_squant == 't' : n += '_squant' else: n += '_nsquant' + + if self.F_trload == 't' : n += '_trload' + else: n += '_ntrload' return n @@ -351,31 +374,39 @@ class FmhaFwdApiPool: @property def api(self) -> str: - per_dtypes=str() - for i, dtype in enumerate(self.pool.keys()): - per_hdim_case=str() - for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): - traits=self.pool[dtype][(hdim, hdim_v)] - inners=str() - for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], - F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, - F_constraint=trait.constraint, - F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, - F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - if not per_dtypes: + tr_load_cond_map = { + "t": "has_load_tr", + "f": "true" + } + + per_tr_load =str() + for tr_load in ["t", "f"]: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][(hdim, hdim_v)] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load], + F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, + F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + 
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(F_if='if', F_trload_cond=tr_load_cond_map[tr_load], F_dtype_case=per_dtypes) + if not per_tr_load: # empty string we add some ignore to suppress warning in api - per_dtypes += ' (void)t ; (void)s ; (void)a;' - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) + per_tr_load += ' (void)t ; (void)s ; (void)a;' + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_tr_load) @dataclass class FmhaFwdTileSize: @@ -458,7 +489,8 @@ class FmhaFwdKernel: F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], F_mode = MODE_MAP[self.F_mode], - F_pipeline = PIPELINE_MAP[self.F_pipeline.tag]) + F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], + F_trload = BOOL_MAP[self.F_pipeline.F_trload]) @property def name(self) -> str: @@ -494,6 +526,7 @@ class FmhaFwdKernel: dpad=self.F_pipeline.F_dpad, dvpad=self.F_pipeline.F_dvpad, skip=self.F_pipeline.F_skip, + tr_load=self.F_pipeline.F_trload, constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint) class KernelComponentFactory: @@ -503,10 +536,15 @@ class KernelComponentFactory: def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - (64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (128,128) : [FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], @@ -534,34 +572,27 @@ class KernelComponentFactory: if dtype in ['fp16', 'bf16']: for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): if hdim == 256 and hdim_v == 256: - # if True: - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', 
logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) # the below two is used for hdim vectorize load - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) else: if bias == "bias": # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) else: - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and lse == "f" and skip == "f": + pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 't')) + pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't')) if receipt == 1 and bias != "bias": - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitrary hdim elif dtype in ['fp8', 'bf8']: # no need lse/dropout kernels for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) elif dtype in ['fp8fp16', 'fp8bf16']: # TODO None
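For orientation before the next hunk: the seqtune predicate added to FmhaFwdApiTrait above steers each problem to the smallest generated tile that covers it in one block — kernels with F_bm0 < 128 guard their dispatch branch with 'a.seqlen_q <= bm0', while the bm0 == 128 kernels emit a constant true and serve as the fallback. A standalone C++ sketch of the selection this produces (pick_bm0 is a hypothetical helper, not part of the patch, and assumes kernels are emitted in the 16/32/128 order of the tile lists above):

// Hypothetical stand-in for the seqlen_q -> kM0 selection performed by the
// generated if/else-if cascade; the first matching kernel wins.
#include <cassert>
#include <initializer_list>

int pick_bm0(int seqlen_q)
{
    for(int bm0 : {16, 32})   // small-tile kernels guard with 'a.seqlen_q <= bm0'
        if(seqlen_q <= bm0)
            return bm0;
    return 128;               // bm0 == 128 kernels emit 'true' and catch everything else
}

int main()
{
    assert(pick_bm0(8) == 16);
    assert(pick_bm0(30) == 32);
    assert(pick_bm0(4096) == 128);
}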
@@ -599,6 +630,12 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl # NOTE: this is used to speedup deepseek prefill case, we don't gen training if pipeline.F_bias != 'no' or pipeline.F_dropout == 't': continue + if pipeline.tag != 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)): + # non-qr_async_trload pipelines only support the kN0=128 tile size when hdim is 128, + # and only the kM0=128 tile size for every other hdim + continue + if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])): + continue # logits_soft_cap is only allowed if no bias if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): continue @@ -665,10 +702,10 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl return (api_pool, gen) def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) + update_file(autogen_dir / kernel.filename, kernel.template) def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: - (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) + update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index c0e4dc3d30..d0f8e3798c 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -1135,7 +1135,7 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, " << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec - << " GB/s" << std::flush; + << " GB/s" << std::flush << std::endl; if(do_validation == 0) { diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 81dda692ea..df1e9e5699 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/host/device_prop.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/fmha.hpp" @@ -1028,6 +1029,7 @@ template struct fmha_fwd_traits_ { @@ -1052,6 +1054,7 @@ struct fmha_fwd_traits_ static constexpr bool kPadSK = kPadSK_; static constexpr bool kPadD = kPadD_; static constexpr bool kPadDv = kPadDv_; + static constexpr bool kUseTrLoad = kUseTrLoad_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; }; diff --git a/example/ck_tile/01_fmha/script/benchmark_fwd.sh b/example/ck_tile/01_fmha/script/benchmark_fwd.sh index 599c595a75..88c16cceb6 100755 --- a/example/ck_tile/01_fmha/script/benchmark_fwd.sh +++ b/example/ck_tile/01_fmha/script/benchmark_fwd.sh @@ -18,14 +18,3 @@ $EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kn done done done - -for perm in 0 1 ; do - -$EXE -prec=fp8
-squant=1 -b=32 -h=16 -d=128 -s=512 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=fp8 -squant=1 -b=16 -h=16 -d=128 -s=1024 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=fp8 -squant=1 -b=8 -h=16 -d=128 -s=2048 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=fp8 -squant=1 -b=4 -h=16 -d=128 -s=4096 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=fp8 -squant=1 -b=2 -h=16 -d=128 -s=8192 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=fp8 -squant=1 -b=1 -h=16 -d=128 -s=16384 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 - -done \ No newline at end of file diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index b867cd6c07..dc2be933bd 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -42,7 +42,6 @@ run_fp16_bf16_tests() { for prec in "fp16" "bf16" ; do for mode in 1 0 ; do for perm in 0 1 ; do - for vlayout in "r" "c" ; do for hdim in 32 64 128 256 ; do for lse in 0 1 ; do for bias in "n" "e" "a" ; do @@ -51,16 +50,16 @@ run_fp16_bf16_tests() { for page_block_size in $PAGE_BLOCK_SIZE ; do for cache_batch_idx in $CACHE_BATCH_IDX ; do - # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -num_splits=$num_splits 
-page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS done ; done ; done ; done ; done 
done ; done ; done ; done ; done diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 35da19cd3e..07be65a150 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -41,10 +41,6 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz { buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; int32x4_t r = __builtin_bit_cast(int32x4_t, res); - r.x = __builtin_amdgcn_readfirstlane(r.x); - r.y = __builtin_amdgcn_readfirstlane(r.y); - r.z = __builtin_amdgcn_readfirstlane(r.z); - r.w = __builtin_amdgcn_readfirstlane(r.w); return r; } @@ -1318,6 +1314,17 @@ enum struct amd_buffer_coherence_enum glc = 1, slc = 2, glc_slc = 3, + // gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1 + // SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system + // NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse + WAVE_NT0 = 0, + WAVE_NT1 = 2, + GROUP_NT0 = 1, + GROUP_NT1 = 3, + DEVICE_NT0 = 8, + DEVICE_NT1 = 10, + SYSTEM_NT0 = 9, + SYSTEM_NT1 = 11, }; template & src_thread_ #if defined(__gfx950__) template -__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr) +__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) { static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32), diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 8c3bc0bc36..c64b296408 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -32,10 +32,6 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz { buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; int32x4_t r = __builtin_bit_cast(int32x4_t, res); - r.x = __builtin_amdgcn_readfirstlane(r.x); - r.y = __builtin_amdgcn_readfirstlane(r.y); - r.z = __builtin_amdgcn_readfirstlane(r.z); - r.w = __builtin_amdgcn_readfirstlane(r.w); return r; } @@ -1186,6 +1182,17 @@ enum struct amd_buffer_coherence_enum glc = 1, slc = 2, glc_slc = 3, + // gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1 + // SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system + // NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse + WAVE_NT0 = 0, + WAVE_NT1 = 2, + GROUP_NT0 = 1, + GROUP_NT1 = 3, + DEVICE_NT0 = 8, + DEVICE_NT1 = 10, + SYSTEM_NT0 = 9, + SYSTEM_NT1 = 11, }; template -__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr) +__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) { static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32), diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index ab42ec8617..f0e9518120 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -89,21 +89,6 @@ CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; } CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; } -CK_TILE_DEVICE void block_sync_lds() -{ -#if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM - // asm volatile("\ - // s_waitcnt lgkmcnt(0) \n \ - // s_barrier \ - // " ::); - - __builtin_amdgcn_s_waitcnt(0xc07f); - __builtin_amdgcn_s_barrier(); -#else - __syncthreads(); -#endif -} - CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0) { #ifdef __gfx12__ @@ -174,6 +159,18 @@ CK_TILE_DEVICE void 
s_waitcnt_barrier() __builtin_amdgcn_s_barrier(); } +template +CK_TILE_DEVICE void block_sync_lds() +{ + s_waitcnt_barrier(); +} + +template +CK_TILE_DEVICE void block_sync_lds_direct_load() +{ + s_waitcnt_barrier(); +} + CK_TILE_DEVICE void s_nop(index_t cnt = 0) { #if 1 diff --git a/include/ck_tile/core/arch/utility.hpp b/include/ck_tile/core/arch/utility.hpp index 7184f99521..93008f8525 100644 --- a/include/ck_tile/core/arch/utility.hpp +++ b/include/ck_tile/core/arch/utility.hpp @@ -59,6 +59,21 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta) #endif } +template <typename T> +CK_TILE_DEVICE auto warp_shuffle_down_pair(const T& v_local) +{ + static_assert(sizeof(T) == sizeof(int32_t), "wrong!"); + + const int32x2_t x = __builtin_amdgcn_permlane32_swap( + bit_cast<int32_t>(v_local), bit_cast<int32_t>(v_local), false, false); + + thread_buffer<T, 2> v; + v(0) = bit_cast<T>(x[0]); + v(1) = bit_cast<T>(x[1]); + + return v; +} + template <typename T> CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane) { diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index c471f416c3..e472bd01e5 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -191,6 +191,16 @@ #endif #endif +// use llvm builtin bf16 data type after ROCm 6.5 +#ifndef CK_TILE_USE_LLVM_BUILTIN_BF16 +#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 5 && HIP_VERSION_PATCH >= 50421) || \ + (HIP_VERSION_MAJOR >= 7) +#define CK_TILE_USE_LLVM_BUILTIN_BF16 1 +#else +#define CK_TILE_USE_LLVM_BUILTIN_BF16 0 +#endif +#endif + #ifndef CK_TILE_DEBUG_LOG #define CK_TILE_DEBUG_LOG 0 #endif diff --git a/include/ck_tile/core/numeric/bfloat16.hpp b/include/ck_tile/core/numeric/bfloat16.hpp index 6f31468809..245fb7244f 100644 --- a/include/ck_tile/core/numeric/bfloat16.hpp +++ b/include/ck_tile/core/numeric/bfloat16.hpp @@ -6,6 +6,9 @@ #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/numeric.hpp" +#if CK_TILE_USE_LLVM_BUILTIN_BF16 +#include +#endif #include #pragma once @@ -102,7 +105,11 @@ struct native_t using bf16_t = bfloat16_t; using bf16_raw_t = typename bf16_t::raw_type; #else +#if CK_TILE_USE_LLVM_BUILTIN_BF16 +using bfloat16_t = __bf16; +#else using bfloat16_t = ushort; +#endif using bf16_t = bfloat16_t; using bf16_raw_t = uint16_t; #endif @@ -280,7 +287,11 @@ template <bf16_rounding_mode rounding = static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)> CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant<rounding> = {}) { +#if defined(__gfx950__) + return static_cast<bfloat16_t>(f); +#else return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{})); +#endif } template using fp32x2_t = float __attribute__((ext_vector_type(2))); using fp16x2_t = _Float16 __attribute__((ext_vector_type(2))); -using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2))); +using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2))); CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x) {
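The CK_TILE_USE_LLVM_BUILTIN_BF16 switch above turns bf16 from a raw ushort payload into the compiler's native __bf16, which lets float_to_bf16 on gfx950 lower to a plain conversion instead of the software rounding helper. A standalone sketch contrasting the two paths (hypothetical helper names, NaN handling omitted; __bf16 requires a bf16-capable Clang such as the ROCm 6.5+ toolchains this gate targets):

// Sketch only: the two float->bf16 flavors the patch selects between.
#include <cstdint>
#include <cstring>

std::uint16_t to_bf16_rne_raw(float f) // legacy raw-ushort path: software round-to-nearest-even
{
    std::uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    u += 0x7fffu + ((u >> 16) & 1u);   // RNE adjustment before truncating the mantissa
    return static_cast<std::uint16_t>(u >> 16);
}

__bf16 to_bf16_builtin(float f)        // builtin path: a single convert instruction on gfx950
{
    return static_cast<__bf16>(f);
}

int main()
{
    const float x = 1.0f / 3.0f;
    const std::uint16_t raw = to_bf16_rne_raw(x);
    const __bf16 b = to_bf16_builtin(x);
    std::uint16_t bits;
    std::memcpy(&bits, &b, sizeof(bits));
    return raw == bits ? 0 : 1;        // both round identically for normal values
}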
diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index 58bdb43b08..bbd3d53827 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -131,12 +131,12 @@ using fp16x64_t = _Float16 __attribute__((ext_vector_type(64))); // bf16 // using bf16_t = ... -using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2))); -using bf16x4_t = bf16_raw_t __attribute__((ext_vector_type(4))); -using bf16x8_t = bf16_raw_t __attribute__((ext_vector_type(8))); -using bf16x16_t = bf16_raw_t __attribute__((ext_vector_type(16))); -using bf16x32_t = bf16_raw_t __attribute__((ext_vector_type(32))); -using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64))); +using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2))); +using bf16x4_t = bfloat16_t __attribute__((ext_vector_type(4))); +using bf16x8_t = bfloat16_t __attribute__((ext_vector_type(8))); +using bf16x16_t = bfloat16_t __attribute__((ext_vector_type(16))); +using bf16x32_t = bfloat16_t __attribute__((ext_vector_type(32))); +using bf16x64_t = bfloat16_t __attribute__((ext_vector_type(64))); // i32 // using int32_t = ... diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index 0e385901ed..b69c167315 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -330,13 +330,6 @@ struct PassThrough y = type_convert(x); } - template <> - CK_TILE_HOST_DEVICE void - operator()(ck_tile::bf16_t& y, const ck_tile::fp16_t& x) const - { - y = type_convert(x); - } - template <> CK_TILE_HOST_DEVICE void operator()(float& y, const ck_tile::fp16_t& x) const diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index d8dd5db12e..69f645b850 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -52,6 +52,8 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
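Before the kernel changes below, note how the trload path is fenced twice: each kernel marks itself unavailable at compile time when kUseTrLoad is set for a non-gfx950 target, and the generated fmha_fwd() additionally checks ck_tile::is_load_tr_supported() at run time before entering the trload branch. A self-contained sketch of the pattern (simplified, hypothetical names; not the actual FmhaFwdKernel):

// Two-level gating sketch: compile-time availability + host-side runtime check.
#include <cstdio>

#if defined(__gfx950__)
inline constexpr bool kTrLoadTarget = true;  // transpose loads exist on this target
#else
inline constexpr bool kTrLoadTarget = false;
#endif

template <bool kUseTrLoad>
struct Kernel
{
    // a trload kernel compiled for a non-gfx950 target becomes a no-op
    static constexpr bool kIsAvailable = !kUseTrLoad || kTrLoadTarget;

    void operator()() const
    {
        if constexpr(kIsAvailable) // body is dead-code eliminated otherwise
            run_();
    }
    void run_() const { std::puts("kernel body"); }
};

int main()
{
    // host side: try the trload kernels first, then fall through to the regular set
    const bool has_load_tr = false;  // stands in for ck_tile::is_load_tr_supported()
    if(has_load_tr) { Kernel<true>{}(); }
    if(true)        { Kernel<false>{}(); }
}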
"_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ); + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kUseTrLoad ? "_trload" : "_ntrload"); #undef _SS_ #undef _TS_ // clang-format on @@ -1036,455 +1045,1142 @@ struct FmhaFwdKernel CK_TILE_DEVICE void operator()(Kargs kargs) const { - // allocate LDS - __shared__ char smem_ptr[GetSmemSize()]; + if constexpr(kIsAvialable) + run_(std::move(kargs)); + } - // divide problem - const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); - - const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); - const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); - - long_index_t batch_offset_q = 0; - long_index_t batch_offset_k = 0; - long_index_t batch_offset_v = 0; - long_index_t batch_offset_bias = 0; - long_index_t batch_offset_randval = 0; - long_index_t batch_offset_lse = 0; - long_index_t batch_offset_o = 0; - - if constexpr(kIsGroupMode) + CK_TILE_DEVICE void run_(Kargs kargs) const + { + if constexpr(kPipelineName != "qr_async_trload") { - // get starting offset for each batch - const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; - batch_offset_q = query_start * kargs.stride_q; - batch_offset_k = key_start * kargs.stride_k; - if constexpr(std::is_same_v) - { - batch_offset_v = key_start * kargs.stride_v; - } - else - { - batch_offset_v = key_start; - } - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - batch_offset_bias = query_start * kargs.stride_bias; - } - if constexpr(kStoreLSE) - { - batch_offset_lse = query_start; - } - if constexpr(kHasDropout) - { - batch_offset_randval = query_start * kargs.stride_randval; - } - batch_offset_o = query_start * kargs.stride_o; + // divide problem + const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); - // get real # queries & # keys under group mode - const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; - kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); + const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); - if constexpr(kSkipMinSeqlenQ) + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_randval = 0; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; + + if constexpr(kIsGroupMode) { - if(kargs.seqlen_q <= kargs.min_seqlen_q) + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + if constexpr(std::is_same_v) + { + batch_offset_v = key_start * kargs.stride_v; + } + else + { + 
batch_offset_v = key_start; + } + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = query_start * kargs.stride_bias; + } + if constexpr(kStoreLSE) + { + batch_offset_lse = query_start; + } + if constexpr(kHasDropout) + { + batch_offset_randval = query_start * kargs.stride_randval; + } + batch_offset_o = query_start * kargs.stride_o; + + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + + if constexpr(kSkipMinSeqlenQ) + { + if(kargs.seqlen_q <= kargs.min_seqlen_q) + { + return; + } + } + + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) { return; } - } - // # of required blocks is different in each groups, terminate unnecessary blocks - // earlier - if(kargs.seqlen_q <= i_m0) - { - return; - } - - if(kargs.seqlen_k_ptr != nullptr) - { - kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + if(kargs.seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + } + else + { + const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; + kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + } } else { - const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; - kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = + static_cast(i_batch) * kargs.batch_stride_bias; + } + if constexpr(kStoreLSE) + { + batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; + } + if constexpr(kHasDropout) + { + batch_offset_randval = + static_cast(i_batch) * kargs.batch_stride_randval; + } + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; } - } - else - { - batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; - batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; - batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; - } - if constexpr(kStoreLSE) - { - batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; - } - if constexpr(kHasDropout) - { - batch_offset_randval = - static_cast(i_batch) * kargs.batch_stride_randval; - } - batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; - } - // for simplicity, batch stride we just modify the pointer - const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + - static_cast(i_nhead) * kargs.nhead_stride_q + - batch_offset_q; - const KDataType* k_ptr = - reinterpret_cast(kargs.k_ptr) + - static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + - batch_offset_k; - const VDataType* v_ptr = - reinterpret_cast(kargs.v_ptr) + - static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + - batch_offset_v; - ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + - static_cast(i_nhead) * kargs.nhead_stride_o + - batch_offset_o; + // for simplicity, batch stride we just modify the pointer + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + const KDataType* k_ptr = + 
reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + + batch_offset_k; + const VDataType* v_ptr = + reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + + batch_offset_v; + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; - // Q/K/V DRAM and DRAM window - const auto q_dram = [&]() { - const auto q_dram_naive = make_naive_tensor_view( - q_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_q), - make_tuple(kargs.stride_q, 1), - number{}, - number<1>{}); - if constexpr(FmhaPipeline::kQLoadOnce) - { - return pad_tensor_view( - q_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - } - else - { - return pad_tensor_view( - q_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - } - }(); - const auto k_dram = [&]() { - const auto k_dram_naive = make_naive_tensor_view( - k_ptr, - make_tuple(kargs.seqlen_k, kargs.hdim_q), - make_tuple(kargs.stride_k, 1), - number{}, - number<1>{}); - - constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; - return pad_tensor_view( - k_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - }(); - const auto v_dram = [&]() { - if constexpr(std::is_same_v) - { - const auto v_dram_naive = make_naive_tensor_view( - v_ptr, - make_tuple(kargs.seqlen_k, kargs.hdim_v), - make_tuple(kargs.stride_v, 1), - number{}, + // Q/K/V DRAM and DRAM window + const auto q_dram = [&]() { + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + if constexpr(FmhaPipeline::kQLoadOnce) + { + return pad_tensor_view(q_dram_naive, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + const auto k_dram = [&]() { + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + number{}, number<1>{}); - - const auto v_dram_transposed = - transform_tensor_view(v_dram_naive, - make_tuple(make_pass_through_transform(kargs.hdim_v), - make_pass_through_transform(kargs.seqlen_k)), - make_tuple(sequence<1>{}, sequence<0>{}), - make_tuple(sequence<0>{}, sequence<1>{})); constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; return pad_tensor_view( - v_dram_transposed, - make_tuple(number{}, number{}), - sequence{}); + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + const auto v_dram = [&]() { + if constexpr(std::is_same_v) + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_v), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + const auto v_dram_transposed = transform_tensor_view( + v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.seqlen_k)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; + return pad_tensor_view( + v_dram_transposed, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.hdim_v, kargs.seqlen_k), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? 
kPadHeadDimV : false; + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + auto q_dram_window = make_tile_window( + q_dram, + [&]() { + if constexpr(FmhaPipeline::kQLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {i_m0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, + make_tuple(number{}, number{}), + {0, 0}); + + auto v_dram_window = make_tile_window( + v_dram, + make_tuple(number{}, number{}), + {i_n1, 0}); + /// FIXME: Before C++20, capturing structured binding variables are not supported. + /// Remove following copy capture of the 'i_nhead' if in C++20 + const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto bias_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + const BiasDataType* bias_ptr = + reinterpret_cast(kargs.bias_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_bias + + batch_offset_bias; + + const auto bias_dram = [&]() { + const auto bias_dram_naive = + make_naive_tensor_view( + bias_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_bias, 1), + number{}, + number<1>{}); + + return pad_tensor_view(bias_dram_naive, + bias_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(bias_dram_window_lengths); + } + }(); + + // lse + auto lse_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto lse_dram_window_lengths = make_tuple(number{}); + if constexpr(kStoreLSE) + { + LSEDataType* lse_ptr = + reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_lse + + batch_offset_lse; + + const auto lse_dram = [&]() { + const auto lse_dram_naive = + make_naive_tensor_view( + lse_ptr, + make_tuple(kargs.seqlen_q), + make_tuple(1), + number<1>{}, + number<1>{}); + + return pad_tensor_view( + lse_dram_naive, lse_dram_window_lengths, sequence{}); + }(); + + return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); + } + else + { + return make_null_tile_window(lse_dram_window_lengths); + } + }(); + + auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() { + if constexpr(kHasDropout) + { + return BlockDropout{i_batch_, + i_nhead_, + kargs.num_head_q, + kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val + : *kargs.drop_seed.ptr, + kargs.is_drop_seed_offset_from_host + ? 
kargs.drop_offset.val + : *kargs.drop_offset.ptr, + kargs.rp_undrop, + kargs.p_undrop_in_uint8_t, + kargs.is_store_randval}; + } + else + { + return NullBlockDropout{}; + }; + }(); + + auto randval_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto randval_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(kHasDropout) + { + RandValOutputDataType* rand_val_ptr = + reinterpret_cast(kargs.rand_val_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_randval + + batch_offset_randval; + + const auto randval_dram = [&]() { + const auto randval_dram_naive = + make_naive_tensor_view( + rand_val_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_randval, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(randval_dram_naive, + randval_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(randval_dram_window_lengths); + } + }(); + + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + else + return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; + }(); + + // WA i_batch capture structure binding before c++20 + auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + // data loading, shared by entire wg + // TODO: how to use s_read? + SaccDataType slope = + *(reinterpret_cast(kargs.alibi_slope_ptr) + + i_batch_ * kargs.alibi_slope_stride + i_nhead_); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + slope *= ck_tile::log2e_v<>; +#endif + if constexpr(kHasMask) + { + return make_alibi_from_lr_mask(slope, + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type); + } + else + { + return Alibi{ + slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; + } + } + else + { + return EmptyPositionEncoding{}; + } + }(); + + AttentionVariant variant; + const auto variant_params = [&] { + if constexpr(kHasLogitsSoftCap) + { + return ck_tile::LogitsSoftCapParams{ + mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; + } + else + { + return ck_tile::StandardAttentionParams{mask, kargs.scale_s}; + } + }(); + + BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; + + auto o_acc_tile = [&]() { + if constexpr(kDoFp8StaticQuant) + { + return FmhaPipeline{}( + q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + bias_dram_window, + identity{}, // bias_element_func + randval_dram_window, + lse_dram_window, + identity{}, // lse_element_func + identity{}, // s_acc_element_func + scales{kargs.scale_p}, // p_compute_element_func + composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout); + } + else + { + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + randval_dram_window, + lse_dram_window, + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout); + } + }(); + + // O DRAM and O DRAM window + auto o_dram = [&]() { + const auto o_dram_naive = 
make_naive_tensor_view<address_space_enum::global>( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_o, 1), + number<FmhaPipeline::kAlignmentO>{}, + number<1>{}); + + return pad_tensor_view( + o_dram_naive, + make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}), + sequence<kPadSeqLenQ, kPadHeadDimV>{}); + }(); + + auto o_dram_window = make_tile_window( + o_dram, + make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}), + {i_m0, i_n1}); + + EpiloguePipeline{}(o_dram_window, o_acc_tile); + } + else + { + // TODO: Refine the logic here. + // In Decode case + // 1. we don't expect KV data to be reused by different ThreadGroups, bypass the cache + // 2. limit the LDS usage, as we want higher occupancy + // In Prefill case + // 1. we expect KV data to be reused by different ThreadGroups, use the cache + // 2. use more LDS, as we want better memory latency hiding + // If SplitKV off, we don't expect Q data to be reused by different ThreadGroups, bypass the + // cache + constexpr bool PrefillCase = FmhaPipeline::kM0 >= 128; + // divide problem + const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); + + const index_t i_m0 = i_tile_m * FmhaPipeline::kM0; + const index_t i_n1 = i_tile_n * FmhaPipeline::kN1; + + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; // unused for paged-kvcache + long_index_t batch_offset_v = 0; // unused for paged-kvcache + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; + // index_t kv_l2p_offset = + // 0; // logical-to-physical offset of seqlen_k coordinate. only used for + // paged-kvcache + + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>) + { + batch_offset_v = key_start * kargs.stride_v; + } + else + { + batch_offset_v = key_start; + } + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = query_start * kargs.stride_bias; + } + + batch_offset_lse = query_start; + batch_offset_o = query_start * kargs.stride_o; + + // get real # queries & # keys under group mode + kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch]; + + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } + + if(kargs.seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + } + else + { + kargs.seqlen_k = + kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch]; + } } else { + batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v; + if constexpr(kStoreLSE) + { + batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse; + } + batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o; + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = + static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias; + } + } + + // for simplicity, batch stride we just modify the pointer + const index_t i_nhead_k = i_nhead / kargs.nhead_ratio_qk; + + const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) + + static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) + + static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k + + batch_offset_k; + const
VDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead_k) * kargs.nhead_stride_v + + batch_offset_v; + + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + + // Q/K/V DRAM and DRAM window + const auto q_dram = [&] { + const auto q_dram_naive = [&] { + { + return make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + } + }(); + + if constexpr(FmhaPipeline::kQLoadOnce) + { + const auto seqlen_q = kargs.seqlen_q; + const auto q_dram_pad = pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); +#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD + constexpr index_t LDSLayerSize = 256 / sizeof(QDataType); + constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim); + + if constexpr(XorLengthFold > 1) + { + const auto q_dram_unmerged = transform_tensor_view( + q_dram_pad, + make_tuple( + make_unmerge_transform( + make_tuple(seqlen_q / XorLengthFold, XorLengthFold)), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + const auto q_dram_merged = transform_tensor_view( + q_dram_unmerged, + make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold), + make_merge_transform_v3_division_mod(make_tuple( + XorLengthFold, number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + const auto q_dram_unmerged_xor = transform_tensor_view( + q_dram_merged, + make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold), + make_unmerge_transform(make_tuple( + number{}, + number{}))), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{})); + + const auto q_dram_permuted = transform_tensor_view( + q_dram_unmerged_xor, + make_tuple( + make_xor_transform( + make_tuple(seqlen_q / XorLengthFold, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + const auto q_dram_tmp = transform_tensor_view( + q_dram_permuted, + make_tuple( + make_pass_through_transform(seqlen_q / XorLengthFold), + make_unmerge_transform( + make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + return transform_tensor_view( + q_dram_tmp, + make_tuple( + make_merge_transform_v3_division_mod( + make_tuple(seqlen_q / XorLengthFold, number{})), + make_merge_transform_v3_division_mod(make_tuple( + number{}, + number{}))), + make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else +#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD + { + const auto q_dram_unmerged = transform_tensor_view( + q_dram_pad, + make_tuple( + make_pass_through_transform(seqlen_q), + make_unmerge_transform(make_tuple( + number{}, + number{}))), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{})); + + const auto q_dram_permuted = transform_tensor_view( + q_dram_unmerged, + make_tuple( + make_xor_transform(make_tuple(seqlen_q, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + return transform_tensor_view( + q_dram_permuted, + make_tuple( + 
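The `CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD` branch above (applied identically to Q here and to K/V below) folds several rows into one 256-byte LDS "layer" when the head dim is small, then XOR-permutes rows within a layer to spread bank accesses. A sketch of the fold arithmetic; the swizzle shown is one plausible mixing function for illustration only, the actual permutation is defined by `make_xor_transform`:

```cpp
#include <cstdio>

int main()
{
    constexpr int kBytesPerLayer = 256;      // one LDS layer
    constexpr int kElemBytes     = 2;        // e.g. fp16 Q/K data
    constexpr int kQKHeaddim     = 64;
    constexpr int LDSLayerSize   = kBytesPerLayer / kElemBytes;  // 128 elems
    constexpr int XorLengthFold  = LDSLayerSize / kQKHeaddim;    // 2 rows/layer

    // permute the rows inside each layer by XOR with the layer index
    auto swizzle_row = [](int row) {
        const int layer = row / XorLengthFold;
        const int slot  = row % XorLengthFold;
        return layer * XorLengthFold + ((slot ^ layer) % XorLengthFold);
    };
    for(int r = 0; r < 6; ++r)
        std::printf("row %d -> %d\n", r, swizzle_row(r));
}
```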
make_pass_through_transform(seqlen_q), + make_merge_transform_v3_division_mod(make_tuple( + number{}, + number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + else + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + const auto make_k_dram = [&](const KDataType* data, index_t height) { + const auto k_dram_naive = make_naive_tensor_view( + data, // will update this pointer if using paged-kvcache + make_tuple(height, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + + const auto k_dram_pad = pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + +#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD + constexpr index_t LDSLayerSize = 256 / sizeof(KDataType); + constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim); + + if constexpr(XorLengthFold > 1) + { + const auto k_dram_unmerged = transform_tensor_view( + k_dram_pad, + make_tuple(make_unmerge_transform( + make_tuple(height / XorLengthFold, XorLengthFold)), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + const auto k_dram_merged = transform_tensor_view( + k_dram_unmerged, + make_tuple(make_pass_through_transform(height / XorLengthFold), + make_merge_transform_v3_division_mod(make_tuple( + XorLengthFold, number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + const auto k_dram_unmerged_xor = transform_tensor_view( + k_dram_merged, + make_tuple(make_pass_through_transform(height / XorLengthFold), + make_unmerge_transform(make_tuple( + number{}, + number{}))), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{})); + + const auto k_dram_permuted = transform_tensor_view( + k_dram_unmerged_xor, + make_tuple( + make_xor_transform( + make_tuple(height / XorLengthFold, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + const auto k_dram_tmp = transform_tensor_view( + k_dram_permuted, + make_tuple( + make_pass_through_transform(height / XorLengthFold), + make_unmerge_transform(make_tuple( + number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + return transform_tensor_view( + k_dram_tmp, + make_tuple( + make_merge_transform_v3_division_mod( + make_tuple(height / XorLengthFold, number{})), + make_merge_transform_v3_division_mod(make_tuple( + number{}, + number{}))), + make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else +#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD + { + const auto k_dram_unmerged = transform_tensor_view( + k_dram_pad, + make_tuple( + make_pass_through_transform(height), + make_unmerge_transform(make_tuple( + number{}, + number{}))), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{})); + + const auto k_dram_permuted = transform_tensor_view( + k_dram_unmerged, + make_tuple( + make_xor_transform(make_tuple( + height, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + return transform_tensor_view( + k_dram_permuted, + make_tuple( + 
make_pass_through_transform(height), + make_merge_transform_v3_division_mod(make_tuple( + number{}, + number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + }; + const auto k_dram = [&]() { + { + return make_k_dram(k_ptr, kargs.seqlen_k); + } + }(); + + const auto make_v_dram = [&](const VDataType* data, index_t length) { const auto v_dram_naive = make_naive_tensor_view( - v_ptr, - make_tuple(kargs.hdim_v, kargs.seqlen_k), - make_tuple(kargs.stride_v, 1), + data, // will update this pointer if using paged-kvcache + make_tuple(length, kargs.hdim_v), + make_tuple(kargs.hdim_v, 1), number{}, number<1>{}); - constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false; - return pad_tensor_view( + // TODO: Add kVHeadDim + constexpr index_t XorGroupSize = + FmhaPipeline::Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}); + + const auto v_dram_pad = pad_tensor_view( v_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - } - }(); + make_tuple(number{}, number{}), + sequence{}); - auto q_dram_window = make_tile_window( - q_dram, - [&]() { - if constexpr(FmhaPipeline::kQLoadOnce) - return make_tuple(number{}, - number{}); +#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD + constexpr index_t LDSLayerSize = 256 / sizeof(VDataType); + constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim); + + if constexpr(XorLengthFold > 1) + { + const auto v_dram_unmerged = transform_tensor_view( + v_dram_pad, + make_tuple(make_unmerge_transform( + make_tuple(length / XorLengthFold, XorLengthFold)), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + const auto v_dram_merged = transform_tensor_view( + v_dram_unmerged, + make_tuple(make_pass_through_transform(length / XorLengthFold), + make_merge_transform_v3_division_mod(make_tuple( + XorLengthFold, number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + const auto v_dram_unmerged_xor = transform_tensor_view( + v_dram_merged, + make_tuple( + make_pass_through_transform(length / XorLengthFold), + make_unmerge_transform(make_tuple(number{}, + number{}))), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{})); + + const auto v_dram_permuted = transform_tensor_view( + v_dram_unmerged_xor, + make_tuple( + make_xor_transform(make_tuple(length / XorLengthFold, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + const auto v_dram_tmp = transform_tensor_view( + v_dram_permuted, + make_tuple(make_pass_through_transform(length / XorLengthFold), + make_unmerge_transform(make_tuple( + number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + return transform_tensor_view( + v_dram_tmp, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(length / XorLengthFold, number{})), + make_merge_transform_v3_division_mod( + make_tuple(number{}, + number{}))), + make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } else - return make_tuple(number{}, number{}); - }(), - {i_m0, 0}); +#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD + { + const auto v_dram_unmerged = transform_tensor_view( + v_dram_pad, + 
make_tuple(make_pass_through_transform(length), + make_unmerge_transform( + make_tuple(number{}, + number{}))), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{})); - auto k_dram_window = make_tile_window( - k_dram, make_tuple(number{}, number{}), {0, 0}); + const auto v_dram_permuted = transform_tensor_view( + v_dram_unmerged, + make_tuple(make_xor_transform(make_tuple( + length, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); - auto v_dram_window = - make_tile_window(v_dram, - make_tuple(number{}, number{}), - {i_n1, 0}); - /// FIXME: Before C++20, capturing structured binding variables are not supported. Remove - /// following copy capture of the 'i_nhead' if in C++20 - const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { - constexpr auto bias_dram_window_lengths = - make_tuple(number{}, number{}); - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - const BiasDataType* bias_ptr = - reinterpret_cast(kargs.bias_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_bias + - batch_offset_bias; - - const auto bias_dram = [&]() { - const auto bias_dram_naive = make_naive_tensor_view( - bias_ptr, - make_tuple(kargs.seqlen_q, kargs.seqlen_k), - make_tuple(kargs.stride_bias, 1), - number{}, - number<1>{}); - - return pad_tensor_view(bias_dram_naive, - bias_dram_window_lengths, - sequence{}); - }(); - - return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); - } - else - { - return make_null_tile_window(bias_dram_window_lengths); - } - }(); - - // lse - auto lse_dram_window = [&, i_nhead_ = i_nhead]() { - constexpr auto lse_dram_window_lengths = make_tuple(number{}); - if constexpr(kStoreLSE) - { - LSEDataType* lse_ptr = - reinterpret_cast(kargs.lse_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse; - - const auto lse_dram = [&]() { - const auto lse_dram_naive = make_naive_tensor_view( - lse_ptr, - make_tuple(kargs.seqlen_q), - make_tuple(1), - number<1>{}, - number<1>{}); - - return pad_tensor_view( - lse_dram_naive, lse_dram_window_lengths, sequence{}); - }(); - - return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); - } - else - { - return make_null_tile_window(lse_dram_window_lengths); - } - }(); - - auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() { - if constexpr(kHasDropout) - { - return BlockDropout{i_batch_, - i_nhead_, - kargs.num_head_q, - kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val - : *kargs.drop_seed.ptr, - kargs.is_drop_seed_offset_from_host ? 
kargs.drop_offset.val - : *kargs.drop_offset.ptr, - kargs.rp_undrop, - kargs.p_undrop_in_uint8_t, - kargs.is_store_randval}; - } - else - { - return NullBlockDropout{}; + return transform_tensor_view( + v_dram_permuted, + make_tuple(make_pass_through_transform(length), + make_merge_transform_v3_division_mod( + make_tuple(number{}, + number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } }; - }(); - auto randval_dram_window = [&, i_nhead_ = i_nhead]() { - constexpr auto randval_dram_window_lengths = - make_tuple(number{}, number{}); - if constexpr(kHasDropout) - { - RandValOutputDataType* rand_val_ptr = - reinterpret_cast(kargs.rand_val_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_randval + - batch_offset_randval; - - const auto randval_dram = [&]() { - const auto randval_dram_naive = - make_naive_tensor_view( - rand_val_ptr, - make_tuple(kargs.seqlen_q, kargs.seqlen_k), - make_tuple(kargs.stride_randval, 1), - number<1>{}, - number<1>{}); - - return pad_tensor_view(randval_dram_naive, - randval_dram_window_lengths, - sequence{}); - }(); - - return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0}); - } - else - { - return make_null_tile_window(randval_dram_window_lengths); - } - }(); - - FmhaMask mask = [&]() { - if constexpr(kHasMask) - return ck_tile::make_generic_attention_mask_from_lr_window( - kargs.window_size_left, - kargs.window_size_right, - kargs.seqlen_q, - kargs.seqlen_k, - kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); - else - return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; - }(); - - // WA i_batch capture structure binding before c++20 - auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - // data loading, shared by entire wg - // TODO: how to use s_read? - SaccDataType slope = - *(reinterpret_cast(kargs.alibi_slope_ptr) + - i_batch_ * kargs.alibi_slope_stride + i_nhead_); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - slope *= ck_tile::log2e_v<>; -#endif - if constexpr(kHasMask) + const auto v_dram = [&]() { { - return make_alibi_from_lr_mask(slope, - kargs.window_size_left, - kargs.window_size_right, - kargs.seqlen_q, - kargs.seqlen_k, - kargs.mask_type); + return make_v_dram(v_ptr, kargs.seqlen_k); + } + }(); + + auto q_dram_window = make_tile_window( + q_dram, + [&]() { + if constexpr(FmhaPipeline::kQLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {i_m0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, + make_tuple(number{}, number{}), + {0, 0}); + + auto v_dram_window = make_tile_window( + v_dram, + make_tuple(number{}, number{}), + {0, 0}); + + /// FIXME: Before C++20, capturing structured binding variables are not supported. 
+ /// Remove following copy capture of the 'i_nhead' if in C++20 + const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto bias_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + const BiasDataType* bias_ptr = + reinterpret_cast(kargs.bias_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_bias + + batch_offset_bias; + + const auto bias_dram = [&]() { + const auto bias_dram_naive = + make_naive_tensor_view( + bias_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_bias, 1), + number{}, + number<1>{}); + + return pad_tensor_view(bias_dram_naive, + bias_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); } else { - return Alibi{ - slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; + return make_null_tile_window(bias_dram_window_lengths); } - } - else - { - return EmptyPositionEncoding{}; - } - }(); + }(); - AttentionVariant variant; - const auto variant_params = [&] { - if constexpr(kHasLogitsSoftCap) - { - return ck_tile::LogitsSoftCapParams{ - mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; - } - else - { - return ck_tile::StandardAttentionParams{mask, kargs.scale_s}; - } - }(); + // lse acc + auto lse_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto lse_dram_window_lengths = make_tuple(number{}); + if constexpr(kStoreLSE) + { + LSEDataType* lse_ptr = + reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_lse + + batch_offset_lse; - BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; + const auto lse_dram = [&] { + const auto lse_dram_naive = [&] { + { + return make_naive_tensor_view( + lse_ptr, + make_tuple(kargs.seqlen_q), + make_tuple(1), + number<1>{}, + number<1>{}); + } + }(); + return pad_tensor_view( + lse_dram_naive, lse_dram_window_lengths, sequence{}); + }(); - auto o_acc_tile = [&]() { - if constexpr(kDoFp8StaticQuant) - { - return FmhaPipeline{}( - q_dram_window, - identity{}, // q_element_func - k_dram_window, - identity{}, // k_element_func - v_dram_window, - identity{}, // v_element_func - bias_dram_window, - identity{}, // bias_element_func - randval_dram_window, - lse_dram_window, - identity{}, // lse_element_func - identity{}, // s_acc_element_func - scales{kargs.scale_p}, // p_compute_element_func - composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func - mask, - position_encoding, - kargs.scale_s, - variant, - variant_params, - block_indices, - smem_ptr, - dropout); - } - else - { - return FmhaPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - bias_dram_window, - randval_dram_window, - lse_dram_window, - mask, - position_encoding, - kargs.scale_s, - variant, - variant_params, - block_indices, - smem_ptr, - dropout); - } - }(); + return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); + } + else + { + return make_null_tile_window(lse_dram_window_lengths); + } + }(); - // O DRAM and O DRAM window - auto o_dram = [&]() { - const auto o_dram_naive = make_naive_tensor_view( - o_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_v), - make_tuple(kargs.stride_o, 1), - number{}, - number<1>{}); + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type == 
GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + else + return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; + }(); - return pad_tensor_view( - o_dram_naive, + // WA i_batch capture structure binding before c++20 + auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + // data loading, shared by entire wg + // TODO: how to use s_read? + SaccDataType slope = + *(reinterpret_cast(kargs.alibi_slope_ptr) + + i_batch_ * kargs.alibi_slope_stride + i_nhead_); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + slope *= ck_tile::log2e_v<>; +#endif + if constexpr(kHasMask) + { + return make_alibi_from_lr_mask( + slope, + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type); + } + else + { + return Alibi{ + slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; + } + } + else + { + return EmptyPositionEncoding{}; + } + }(); + + auto o_acc_tile = [&]() { + if constexpr(PrefillCase) + { + // allocate double lds + // add __restrict__ here to avoid aliasing + __shared__ char smem_ptrk0 + [FmhaPipeline::Policy::template GetSmemSizeK()]; + __shared__ char smem_ptrk1 + [FmhaPipeline::Policy::template GetSmemSizeK()]; + __shared__ char smem_ptrv0[FmhaPipeline::Policy::template GetSmemSizeV< + typename FmhaPipeline::Problem>()]; + __shared__ char smem_ptrv1[FmhaPipeline::Policy::template GetSmemSizeV< + typename FmhaPipeline::Problem>()]; + + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + lse_dram_window, + mask, + position_encoding, + kargs.scale_s, + smem_ptrk0, + smem_ptrk1, + smem_ptrv0, + smem_ptrv1); + } + else + { + __shared__ char smem_ptr[GetSmemSize()]; + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + lse_dram_window, + mask, + position_encoding, + kargs.scale_s, + smem_ptr); + } + }(); + + // Oacc DRAM and Oacc DRAM window + auto o_dram = [&] { + const auto o_dram_naive = [&] { + { + return make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_o, 1), + number{}, + number<1>{}); + } + }(); + + return pad_tensor_view( + o_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto o_dram_window = make_tile_window( + o_dram, make_tuple(number{}, number{}), - sequence{}); - }(); + {i_m0, i_n1}); - auto o_dram_window = - make_tile_window(o_dram, - make_tuple(number{}, number{}), - {i_m0, i_n1}); - - EpiloguePipeline{}(o_dram_window, o_acc_tile); + EpiloguePipeline{}(o_dram_window, o_acc_tile); + } } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index aa2ec99590..f6a20c5cb5 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -1038,7 +1038,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}; constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1096,7 +1096,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}; constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1190,7 +1190,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, 
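The `PrefillCase` branch above (`kM0 >= 128`) allocates two K and two V LDS buffers and the prefill pipeline ping-pongs between them by loop parity (see the `mainloop` in the new pipeline further down). A tiny host-side sketch of that parity-driven selection, following the K-buffer convention used there (even iterations write buffer 0 while reading buffer 1); `buf0`/`buf1` stand in for `smem_ptrk0`/`smem_ptrk1`:

```cpp
#include <cstdio>

int main()
{
    char buf0[16], buf1[16];
    for(int loop = 0; loop < 4; ++loop)
    {
        const bool is_even = (loop % 2 == 0);
        char* write_ptr = is_even ? buf0 : buf1; // next tile lands here
        char* read_ptr  = is_even ? buf1 : buf0; // gemm consumes this one
        std::printf("loop %d: write %s, read %s\n", loop,
                    write_ptr == buf0 ? "buf0" : "buf1",
                    read_ptr  == buf0 ? "buf0" : "buf1");
    }
}
```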
tuple>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}; constexpr auto kt_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1249,7 +1249,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}; constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1344,7 +1344,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}; constexpr auto qt_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1379,7 +1379,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}; constexpr auto dst_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1490,7 +1490,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}; constexpr auto do_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1589,7 +1589,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}; constexpr auto dot_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1623,7 +1623,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}; constexpr auto pt_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1667,7 +1667,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}; constexpr auto ds_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -1718,7 +1718,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); pt_out.set_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths), pt_warp_tensor.get_thread_buffer()); }); @@ -1768,7 +1768,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); dst_out.set_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths), dst_warp_tensor.get_thread_buffer()); }); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp index cf70dff63f..45a1c8f4b8 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp @@ -11,6 +11,7 @@ enum class BlockFmhaPipelineEnum QRKSVS = 0, QRKSVS_ASYNC, QSKSVS, + QRKSVS_ASYNC_TRLOAD, }; template @@ -32,4 +33,10 @@ struct BlockFmhaPipelineEnumToStr static constexpr const char* name = "qs"; }; +template <> +struct BlockFmhaPipelineEnumToStr +{ + static constexpr const char* name = "qr_async_trload"; +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 20b30b7417..86ac713b6f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -22,6 +22,7 @@ template struct 
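Alongside the policy's tile-distribution reorderings, the pipeline enum gains a `QRKSVS_ASYNC_TRLOAD` member with a matching `BlockFmhaPipelineEnumToStr` specialization. A reduced stand-alone version of that compile-time enum-to-string mapping (type names shortened for illustration):

```cpp
#include <cstdio>

enum class PipelineEnum { QRKSVS, QRKSVS_ASYNC, QSKSVS, QRKSVS_ASYNC_TRLOAD };

template <PipelineEnum> struct PipelineEnumToStr;
template <> struct PipelineEnumToStr<PipelineEnum::QRKSVS>
{ static constexpr const char* name = "qr"; };
template <> struct PipelineEnumToStr<PipelineEnum::QRKSVS_ASYNC_TRLOAD>
{ static constexpr const char* name = "qr_async_trload"; };

int main()
{
    // resolved entirely at compile time; no runtime lookup table
    std::printf("%s\n",
                PipelineEnumToStr<PipelineEnum::QRKSVS_ASYNC_TRLOAD>::name);
}
```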
BlockFmhaPipelineProblem { @@ -46,6 +47,7 @@ struct BlockFmhaPipelineProblem static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr bool kUseTrLoad = kUseTrLoad_; // attributes from traits static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp new file mode 100644 index 0000000000..39d8814692 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -0,0 +1,1177 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +// This pipeline is qkv all located in LDS +template +struct BlockFmhaPipelineQRKSVSAsyncTrload +{ + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using AttentionVariant = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once + static_assert(kQLoadOnce == Policy::QLoadOnce); + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + static constexpr index_t kNWarp = BlockFmhaShape::Gemm0BlockWarps::at(I1); + static constexpr index_t kNXdl = BlockFmhaShape::Gemm0WarpTile::at(I1); + + static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + + // static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true && + // Problem::kPadHeadDimV == true); + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = + Problem::kPadHeadDimQ; // support multiple of vector(like 8x) + static constexpr bool kPadHeadDimV = + Problem::kPadHeadDimV; // support multiple of vector(like 8x) + + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + 
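The tile constants declared above determine the sub-iteration counts used later in the pipeline: the QK gemm walks the head dimension in `kK0` chunks and the PV gemm walks the `kN0` key block in `kK1` chunks. A sketch of that arithmetic with illustrative values:

```cpp
#include <cstdio>

int main()
{
    constexpr int kQKHeaddim = 128, kK0 = 32; // head dim / QK gemm K-chunk
    constexpr int kN0 = 128,        kK1 = 32; // key block / PV gemm K-chunk

    constexpr int k0_loops = kQKHeaddim / kK0; // 4 QK sub-iterations
    constexpr int k1_loops = kN0 / kK1;        // 4 PV sub-iterations
    static_assert(1 <= k0_loops && 1 <= k1_loops,
                  "mirrors the static_asserts in the pipeline mainloop");
    std::printf("k0_loops=%d k1_loops=%d\n", k0_loops, k1_loops);
}
```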
static constexpr bool kHasUnevenSplits = true; + + static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && + (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || + !kHasLogitsSoftCap)) || + (!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap)); + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(std::is_same_v) + return Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + }(); + + static constexpr index_t kAlignmentOacc = Policy::template GetAlignmentO(); + + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + if constexpr(kQKHeaddim <= 32) + { + return 2; + } + else if constexpr(kQKHeaddim <= 64) + { + return 3; + } + else if constexpr(kQKHeaddim <= 128) + { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || kM0 >= 256) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim <= 256) + { + return 1; + } + else + { + return 1; + } + } + }(); + + static constexpr const char* name = "qr_async_trload"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + // Decode + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + void* smem_ptr) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[I0] && + kSubQKHeaddim == QDramBlockWindowTmp{}.get_window_lengths()[I1] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[I0] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[I1] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[I0] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[I1] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[I0] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[I1], + "wrong!"); + ignore = bias_dram_block_window_tmp; + ignore = position_encoding; + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPVBlockGemm(); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + auto o_acc = OaccBlockTileType{}; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(o_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, sequence<1>{}, f_max, 
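The `kBlockPerCu` initializer above is an occupancy heuristic: smaller head dims leave more LDS and registers free, so more blocks fit per CU. Restated as a stand-alone `constexpr` function, with the branch values taken directly from the code above:

```cpp
#include <cstdio>

constexpr int blocks_per_cu(int hdim, bool elementwise_bias, int m0)
{
    if(hdim <= 32) return 2;
    if(hdim <= 64) return 3;
    if(hdim <= 128) return (elementwise_bias || m0 >= 256) ? 1 : 2;
    return 1; // hdim <= 256 and the fallback case
}

int main()
{
    std::printf("%d %d %d\n",
                blocks_per_cu(64, false, 128),   // 3
                blocks_per_cu(128, true, 128),   // 1
                blocks_per_cu(256, false, 128)); // 1
}
```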
SMPLComputeDataType{0})); + + // init M, L + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, -numeric::infinity()); + clear_tile(l); + + const auto q_origin = q_dram_block_window_tmp.get_window_origin(); + const auto [logical_seqlen_k_start, logical_seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.at(I0), number{}, number{}); + + // check early exit if no work to do + if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits) + { + const index_t logical_num_total_loop = + integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0); + if(logical_num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse_acc = + make_static_distributed_tensor(m.get_tile_distribution()); + + set_tile(lse_acc, -numeric::infinity()); + + if(get_thread_local_1d_id() < kM0) + { + store_tile(lse_acc_dram_window_tmp, lse_acc); + } + } + + // Note: here occ are all cleard, return it + // Note: q loaded but no fence, ignore it. + return o_acc; + } + } + + // Q tile in LDS + auto q_dram_window = make_tile_window( + q_dram_block_window_tmp, Policy::template MakeQDramTileDistribution()); + + auto q_lds_write_view = make_tensor_view( + static_cast(smem_ptr), Policy::template MakeQLdsBlockDescriptor()); + + auto q_lds_read_view = make_tensor_view( + static_cast(smem_ptr), + Policy::template MakeQLdsBlockDescriptor()); + + auto q_lds_store_window = + make_tile_window(q_lds_write_view, + Policy::template MakeQLdsBlockDescriptor().get_lengths(), + {0, 0}); + + auto q_lds_read_window = + make_tile_window(q_lds_read_view, + Policy::template MakeQLdsBlockDescriptor().get_lengths(), + {0, 0}, + Policy::template MakeQRegTileDistribution()); + + async_load_tile(q_lds_store_window, q_dram_window); + + // K tile in LDS + const index_t physical_seqlen_k_start = logical_seqlen_k_start; + const index_t physical_seqlen_k_end = logical_seqlen_k_end; + // make sure the first tile is completely located in page-block (page-block size should be + // divisible by kN0) + // relationship between each *_start variables: aligned_physical_seqlen_k_start <= + // physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start + const index_t aligned_physical_seqlen_k_start = physical_seqlen_k_start; + + auto k_dram_window = make_tile_window( + k_dram_block_window_tmp, Policy::template MakeKDramTileDistribution()); + + auto k_lds_write_view = make_tensor_view( + static_cast(smem_ptr), Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_read_view = make_tensor_view( + static_cast(smem_ptr), + Policy::template MakeKLdsBlockDescriptor()); + + auto k_lds_write_window = + make_tile_window(k_lds_write_view, + Policy::template MakeKLdsBlockDescriptor().get_lengths(), + {0, 0}); + auto k_lds_read_window = + make_tile_window(k_lds_read_view, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeKRegTileDistribution()); + + // S tile in LDS + auto s_lds = make_tensor_view( + reinterpret_cast(reinterpret_cast(smem_ptr) + + Policy::template GetSmemSizeK()), + Policy::template MakeSLdsBlockDescriptor()); + auto s_write_lds_window = make_tile_window( + s_lds, Policy::template MakeSLdsBlockDescriptor().get_lengths(), {0, 0}); + auto s_read_lds_window = + make_tile_window(s_lds, + Policy::template MakeSLdsBlockDescriptor().get_lengths(), + {0, 0}, + Policy::template MakeSRegTileDistribution()); + + // V tile in LDS + auto v_dram_window = make_tile_window( + v_dram_block_window_tmp, Policy::template MakeVDramTileDistribution()); + + auto v_lds_write_view = 
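The early-exit check above works off a per-query-block key range from the mask: ceil-dividing the range by `kN0` gives the trip count, and a non-positive count means the block only needs to store `-inf` LSE (when `kStoreLSE`) and return the cleared accumulator. A scalar sketch:

```cpp
#include <cstdio>

constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    constexpr int kN0 = 128;
    const int k_start = 256, k_end = 256;   // empty range, e.g. fully masked
    const int num_loops = integer_divide_ceil(k_end - k_start, kN0);
    if(num_loops <= 0)
    {
        // kernel path: set lse_acc to -inf, return zeroed o_acc
        std::printf("early exit, no k-tiles\n");
        return 0;
    }
    std::printf("run %d k-tiles\n", num_loops);
}
```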
make_tensor_view( + reinterpret_cast(static_cast(smem_ptr) + + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeS()), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_read_view = make_tensor_view( + reinterpret_cast(static_cast(smem_ptr) + + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeS()), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_write_window = + make_tile_window(v_lds_write_view, + Policy::template MakeVLdsBlockDescriptor().get_lengths(), + {0, 0}); + + auto v_lds_read_window = + make_tile_window(v_lds_read_view, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeVRegTileDistribution()); + + block_sync_lds_direct_load<0>(); + auto q_tile = load_tile(q_lds_read_window); + + const index_t num_total_loop = + integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0); + + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(1 <= k0_loops); + static_assert(1 <= k1_loops); + + block_sync_lds(); + async_load_tile(k_lds_write_window, k_dram_window); + + constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access(); + constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access(); + + do + { + block_sync_lds(); + async_load_tile(v_lds_write_window, v_dram_window); // prefetch load v tile + + // move V tile windows + move_tile_window(v_dram_window, {kN0, 0}); + + // STAGE 1, QK gemm + clear_tile(s_acc); // initialize C + + if constexpr(1 < k0_loops) + { + static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { + if constexpr(i_k0 == 0) + { + block_sync_lds_direct_load(); + } + else + { + block_sync_lds_direct_load<0>(); + } + + auto k_tile = load_tile(k_lds_read_window); + + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, i_k0 * kK0>{}, + sequence{}), + k_tile); + + // loop over along the [K]ey head dimension + move_tile_window(k_dram_window, {0, kK0}); + block_sync_lds(); + async_load_tile(k_lds_write_window, k_dram_window); + }); + // move back to the origin + move_tile_window(k_dram_window, {0, -kK0 * (k0_loops - 1)}); + } + + if constexpr(k0_loops == 1) + { + block_sync_lds_direct_load(); + } + else + { + block_sync_lds_direct_load<0>(); + } + + auto k_tile = load_tile(k_lds_read_window); + + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + k_tile); + + if constexpr(kHasUnevenSplits) + { + if(i_total_loops == (num_total_loop - 1)) + { + const auto k_origin = make_tuple(kN0 * i_total_loops, 0); + set_tile_if(s_acc, + -numeric::infinity(), + [&, + physical_seqlen_k_start_ = physical_seqlen_k_start, + physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) { + const auto col = k_origin.at(I0) + tile_idx.at(I1); + + { + return physical_seqlen_k_end_ <= col; + } + }); + } + } + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = make_tuple(kN0 * i_total_loops, 0); + + bool need_perpixel_check = + mask.IsEdgeTile(q_origin.at(I0), k_origin.at(I0), number{}, number{}); + if(need_perpixel_check) + { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(I0) + tile_idx.at(I0); + const auto col = k_origin.at(I0) + tile_idx.at(I1); + return mask.IsOutOfBound(row, col); + }); + } + } + + // move K tile windows after current status checked + // prefetch next-tile along [K]ey sequence length dimension + move_tile_window(k_dram_window, {kN0, 0}); + + block_sync_lds(); + 
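In this single-buffer decode path, K, S, and V share one LDS allocation: S starts at offset `GetSmemSizeK()` and V at `GetSmemSizeK() + GetSmemSizeS()`, exactly as the pointer arithmetic above shows. A sketch of the carve-up with illustrative sizes standing in for the Policy getters:

```cpp
#include <cstdio>
#include <cstddef>

int main()
{
    constexpr size_t smem_size_k = 16384, smem_size_s = 4096,
                     smem_size_v = 16384;
    static char smem[smem_size_k + smem_size_s + smem_size_v]; // one block's LDS

    char* k_base = smem;                             // K tile at the base
    char* s_base = smem + smem_size_k;               // S after K
    char* v_base = smem + smem_size_k + smem_size_s; // V after K + S
    std::printf("K@%zu S@%zu V@%zu\n",
                size_t(k_base - smem), size_t(s_base - smem),
                size_t(v_base - smem));
}
```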
async_load_tile(k_lds_write_window, k_dram_window); + + // Gemm1 + auto s_new = [&]() { + if constexpr(kNWarp > 1) + { + auto s = cast_tile(s_acc); // S{j} + + store_tile(s_write_lds_window, s); + block_sync_lds(); + return load_tile(s_read_lds_window); + } + else + { + return cast_tile(s_acc); // S{j} + } + }(); + + auto m_local = block_tile_reduce( + s_new, + sequence<1>{}, + f_max, + -numeric::infinity()); // m_local = rowmax(S{j}) + // Set CrossWarp to false will trigger better strategy on gfx950, but will cause + // performance regression because of un-coexecutable packed math, silent it for now + block_tile_reduce_sync( + m_local, f_max, bool_constant{} /*, bool_constant{}*/); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s_new.get_tile_distribution()); // Pcompute{j} + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_m == -numeric::infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + auto row_max = scale_s * get_validated_m(m[i_idx]); + sweep_tile_span(p_spans[I1], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + if constexpr(kHasLogitsSoftCap) + { + p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max); + } + } + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync( + rowsum_p, f_sum, bool_constant{} /*, bool_constant{}*/); + + auto p_tile = make_static_distributed_tensor( + Policy::template MakePRegTileDistribution()); + p_tile.get_thread_buffer() = cast_tile(p_compute).get_thread_buffer(); + + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + if constexpr(kHasLogitsSoftCap) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } + } + }(); + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[I1], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + o_acc(i_j_idx) *= tmp; + }); + }); + + block_sync_lds_direct_load(); + + auto v_tile = load_tile_transpose(v_lds_read_window); + + if constexpr(1 < k1_loops) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + gemm_1(o_acc, + get_slice_tile(p_tile, + sequence<0, i_k1 * kK1>{}, + sequence{}), + v_tile); + + // loop over along the 
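The block that follows is the online-softmax core: fold this tile's row-max into the running max `m`, then form `p = exp2(scale_s * s - scale_s * m)`. `exp2` is used instead of `exp` because, in the `CK_TILE_FMHA_FWD_FAST_EXP2` path, the `log2(e)` factor is folded into the scale beforehand (the alibi slope handling earlier in this diff does the same). A scalar, one-row sketch:

```cpp
#include <cmath>
#include <cstdio>

int main()
{
    const float scale_s = 0.125f;              // assumed to already carry log2(e)
    float m = -INFINITY;                       // running row max
    const float s_tile[] = {1.0f, 3.0f, 2.0f}; // this tile's logits (one row)

    float m_local = -INFINITY;                 // rowmax(S{j})
    for(float s : s_tile) m_local = std::fmax(m_local, s);
    const float m_old = m;
    m = std::fmax(m_old, m_local);             // m{j} = max(m{j-1}, m_local)

    for(float s : s_tile)
        std::printf("p=%f\n", std::exp2(scale_s * s - scale_s * m));
    // m_old also rescales l and o_acc by exp2(scale_s*m_old - scale_s*m),
    // which is what the sweep over o_spans below implements.
    (void)m_old;
}
```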
[V]alue Sequence length + move_tile_window(v_lds_read_window, {kK1, 0}); + v_tile = load_tile_transpose(v_lds_read_window); + }); + // move back to the origin + move_tile_window(v_lds_read_window, {-kK1 * (k1_loops - 1), 0}); + } + + gemm_1(o_acc, + get_slice_tile(p_tile, + sequence<0, (k1_loops - 1) * kK1>{}, + sequence{}), + v_tile); + + } while(++i_total_loops < num_total_loop); + + if constexpr(kStoreLSE) + { + // store lse acc + auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans(); + sweep_tile_span(lse_acc_spans[I0], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); + } + else + { + if constexpr(kHasLogitsSoftCap) + { + lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); + } + else + { + lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); + } + } + }); + + if(get_thread_local_1d_id() < kM0) + { + store_tile(lse_acc_dram_window_tmp, lse_acc); + } + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[I1], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + return o_acc; + } + + // Prefill, double lds + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + void* __restrict__ smem_ptrk0, + void* __restrict__ smem_ptrk1, + void* __restrict__ smem_ptrv0, + void* __restrict__ smem_ptrv1) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[I0] && + kSubQKHeaddim == QDramBlockWindowTmp{}.get_window_lengths()[I1] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[I0] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[I1] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[I0] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[I1] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[I0] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[I1], + "wrong!"); + ignore = bias_dram_block_window_tmp; + ignore = position_encoding; + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPVBlockGemm(); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + auto o_acc = 
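The epilogue just shown finalizes the two running statistics: `lse = m / C_LOG2E + log(l)` converts the base-2 running max back into a natural-log LSE, and each `o_acc` row is normalized by `1/l`, with a zero guard for rows that masking emptied entirely. A scalar sketch of both steps:

```cpp
#include <cmath>
#include <cstdio>

int main()
{
    const float m = 4.0f; // running row max (base-2 exponent domain)
    const float l = 2.5f; // running row sum of exp2 terms
    const float C_LOG2E = 1.44269504088896340736f; // log2(e)

    const float lse   = m / C_LOG2E + std::log(l);   // natural-log LSE
    const float inv_l = (l == 0.f) ? 0.f : 1.f / l;  // guard fully-masked rows
    std::printf("lse=%f inv_l=%f\n", lse, inv_l);
}
```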
OaccBlockTileType{}; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(o_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); + + // init M, L + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, -numeric::infinity()); + clear_tile(l); + + const auto q_origin = q_dram_block_window_tmp.get_window_origin(); + const auto [logical_seqlen_k_start, logical_seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.at(I0), number{}, number{}); + + // check early exit if no work to do + if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits) + { + const index_t logical_num_total_loop = + integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0); + if(logical_num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse_acc = + make_static_distributed_tensor(m.get_tile_distribution()); + + set_tile(lse_acc, -numeric::infinity()); + + if(get_thread_local_1d_id() < kM0) + { + store_tile(lse_acc_dram_window_tmp, lse_acc); + } + } + + // Note: here occ are all cleard, return it + // Note: q loaded but no fence, ignore it. + return o_acc; + } + } + + // Q tile in LDS + auto q_dram_window = make_tile_window( + q_dram_block_window_tmp, Policy::template MakeQDramTileDistribution()); + + auto q_lds_write_view = make_tensor_view( + static_cast(smem_ptrk0), + Policy::template MakeQLdsBlockDescriptor()); + + auto q_lds_read_view = make_tensor_view( + static_cast(smem_ptrk0), + Policy::template MakeQLdsBlockDescriptor()); + + auto q_lds_store_window = + make_tile_window(q_lds_write_view, + Policy::template MakeQLdsBlockDescriptor().get_lengths(), + {0, 0}); + + auto q_lds_read_window = + make_tile_window(q_lds_read_view, + Policy::template MakeQLdsBlockDescriptor().get_lengths(), + {0, 0}, + Policy::template MakeQRegTileDistribution()); + + async_load_tile(q_lds_store_window, q_dram_window); + block_sync_lds_direct_load<0>(); + auto q_tile = load_tile(q_lds_read_window); + + // K tile in LDS + const index_t physical_seqlen_k_start = logical_seqlen_k_start; + const index_t physical_seqlen_k_end = logical_seqlen_k_end; + // make sure the first tile is completely located in page-block (page-block size should be + // divisible by kN0) + // relationship between each *_start variables: aligned_physical_seqlen_k_start <= + // physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start + const index_t aligned_physical_seqlen_k_start = physical_seqlen_k_start; + + auto k_dram_window = make_tile_window( + k_dram_block_window_tmp, Policy::template MakeKDramTileDistribution()); + + auto k_lds_write_view = make_tensor_view( + static_cast(smem_ptrk0), + Policy::template MakeKLdsBlockDescriptor()); + + auto k_lds_read_view = make_tensor_view( + static_cast(smem_ptrk0), + Policy::template MakeKLdsBlockDescriptor()); + + auto k_lds_write_window = + make_tile_window(k_lds_write_view, + Policy::template MakeKLdsBlockDescriptor().get_lengths(), + {0, 0}); + + auto k_lds_read_window = + make_tile_window(k_lds_read_view, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeKRegTileDistribution()); + + // S tile in LDS + auto s_lds = make_tensor_view( + reinterpret_cast(reinterpret_cast(smem_ptrk0) + + Policy::template GetSmemSizeK()), + Policy::template MakeSLdsBlockDescriptor()); + auto s_write_lds_window = make_tile_window( + s_lds, Policy::template MakeSLdsBlockDescriptor().get_lengths(), {0, 0}); + auto s_read_lds_window = 
+ make_tile_window(s_lds, + Policy::template MakeSLdsBlockDescriptor().get_lengths(), + {0, 0}, + Policy::template MakeSRegTileDistribution()); + + // V tile in LDS + auto v_dram_window = make_tile_window( + v_dram_block_window_tmp, Policy::template MakeVDramTileDistribution()); + + auto v_lds_write_view = make_tensor_view( + reinterpret_cast(static_cast(smem_ptrv0)), + Policy::template MakeVLdsBlockDescriptor()); + + auto v_lds_read_view = make_tensor_view( + reinterpret_cast(static_cast(smem_ptrv0)), + Policy::template MakeVLdsBlockDescriptor()); + + auto v_lds_write_window = + make_tile_window(v_lds_write_view, + Policy::template MakeVLdsBlockDescriptor().get_lengths(), + {0, 0}); + + auto v_lds_read_window = + make_tile_window(v_lds_read_view, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeVRegTileDistribution()); + + // block_sync_lds_direct_load<0>(); + // auto q_tile = load_tile(q_lds_read_window); + + const index_t num_total_loop = + integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0); + + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(1 <= k0_loops); + static_assert(1 <= k1_loops); + block_sync_lds<0>(); + async_load_tile(k_lds_write_window, k_dram_window); + async_load_tile(v_lds_write_window, v_dram_window); + + move_tile_window(k_dram_window, {kN0, 0}); + k_lds_write_window.set_bottom_tensor_view_data_ptr( + static_cast(smem_ptrk1)); + async_load_tile(k_lds_write_window, k_dram_window); + + constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access(); + constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access(); + + constexpr index_t k_lds_insts = k_lds_read_window.get_num_of_access(); + constexpr index_t v_lds_insts = v_lds_read_window.get_num_of_access(); + + block_sync_lds_direct_load(); + auto k_tile = load_tile(k_lds_read_window); + + __builtin_amdgcn_sched_barrier(0); + + auto mainloop = [&](index_t cur_loop) { + const bool is_even_loop = (cur_loop % 2 == 0); + + auto k_lds_write_ptr = is_even_loop ? static_cast(smem_ptrk0) + : static_cast(smem_ptrk1); + auto k_lds_read_ptr = is_even_loop ? static_cast(smem_ptrk1) + : static_cast(smem_ptrk0); + auto v_lds_write_ptr = is_even_loop ? static_cast(smem_ptrv1) + : static_cast(smem_ptrv0); + auto v_lds_read_ptr = is_even_loop ? 
static_cast(smem_ptrv0) + : static_cast(smem_ptrv1); + + // move V tile windows + block_sync_lds(); + move_tile_window(v_dram_window, {kN0, 0}); + v_lds_write_window.set_bottom_tensor_view_data_ptr(v_lds_write_ptr); + async_load_tile(v_lds_write_window, v_dram_window); + + // STAGE 1, QK gemm + clear_tile(s_acc); // initialize C + + if constexpr(1 < k0_loops) + { + static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { + // loop over along the [K]ey head dimension + move_tile_window(k_lds_read_window, {0, kK0}); + auto k_tile_switch = load_tile(k_lds_read_window); + + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, i_k0 * kK0>{}, + sequence{}), + k_tile); + + k_tile = k_tile_switch; + }); + // move back to the origin + move_tile_window(k_lds_read_window, {0, -kK0 * (k0_loops - 1)}); + } + + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + k_tile); + + block_sync_lds_direct_load(); + v_lds_read_window.set_bottom_tensor_view_data_ptr(v_lds_read_ptr); + auto v_tile = load_tile_transpose(v_lds_read_window); + + if constexpr(kHasUnevenSplits) + { + if(i_total_loops == (num_total_loop - 1)) + { + const auto k_origin = make_tuple(kN0 * i_total_loops, 0); + set_tile_if(s_acc, + -numeric::infinity(), + [&, + physical_seqlen_k_start_ = physical_seqlen_k_start, + physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) { + const auto col = k_origin.at(I0) + tile_idx.at(I1); + + { + return physical_seqlen_k_end_ <= col; + } + }); + } + } + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = make_tuple(kN0 * i_total_loops, 0); + + bool need_perpixel_check = + mask.IsEdgeTile(q_origin.at(I0), k_origin.at(I0), number{}, number{}); + if(need_perpixel_check) + { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(I0) + tile_idx.at(I0); + const auto col = k_origin.at(I0) + tile_idx.at(I1); + return mask.IsOutOfBound(row, col); + }); + } + } + + // Gemm1 + auto s_new = [&]() { + if constexpr(kNWarp > 1) + { + auto s = cast_tile(s_acc); // S{j} + + store_tile(s_write_lds_window, s); + block_sync_lds(); + return load_tile(s_read_lds_window); + } + else + { + return cast_tile(s_acc); // S{j} + } + }(); + + auto m_local = block_tile_reduce( + s_new, + sequence<1>{}, + f_max, + -numeric::infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync( + m_local, f_max, bool_constant{} /*, bool_constant{}*/); + + static_for<0, 12, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ + }); + + static_for<0, 4, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS_READ + }); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s_new.get_tile_distribution()); // Pcompute{j} + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_m == -numeric::infinity() + ? 
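The scheduling hints interleaved with the softmax math above pin the instruction mix: each `__builtin_amdgcn_sched_group_barrier(mask, size, sync_id)` asks the compiler to place `size` instructions matching `mask` at that point, where (per the comments in this file) `0x008` selects MFMA and `0x100` selects DS reads. The pattern used here, restated as a device-only fragment quoting the mainloop:

```cpp
// 12 alternating pairs: one MFMA, then one LDS read ...
static_for<0, 12, 1>{}([&](auto i) {
    ignore = i;
    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // 1 MFMA
    __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // 1 DS_READ
});
// ... then 4 groups with a 2:1 DS_READ ratio to drain remaining LDS reads
static_for<0, 4, 1>{}([&](auto i) {
    ignore = i;
    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // 1 MFMA
    __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // 2 DS_READ
});
```

The second barrier cluster later in the loop inverts the ratios, which keeps LDS traffic for the next K tile overlapped with the PV gemm's MFMAs.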
type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + auto row_max = scale_s * get_validated_m(m[i_idx]); + sweep_tile_span(p_spans[I1], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + if constexpr(kHasLogitsSoftCap) + { + p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max); + } + } + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync( + rowsum_p, f_sum, bool_constant{} /*, bool_constant{}*/); + + auto p_tile = make_static_distributed_tensor( + Policy::template MakePRegTileDistribution()); + p_tile.get_thread_buffer() = cast_tile(p_compute).get_thread_buffer(); + + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + if constexpr(kHasLogitsSoftCap) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } + } + }(); + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[I1], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + o_acc(i_j_idx) *= tmp; + }); + }); + + block_sync_lds(); + move_tile_window(k_dram_window, {kN0, 0}); + k_lds_write_window.set_bottom_tensor_view_data_ptr(k_lds_write_ptr); + async_load_tile(k_lds_write_window, k_dram_window); + + if constexpr(1 < k1_loops) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + // loop over along the [V]alue Sequence length + move_tile_window(v_lds_read_window, {kK1, 0}); + auto v_tile_switch = load_tile_transpose(v_lds_read_window); + + gemm_1(o_acc, + get_slice_tile(p_tile, + sequence<0, i_k1 * kK1>{}, + sequence{}), + v_tile); + + v_tile = v_tile_switch; + }); + // move back to the origin + move_tile_window(v_lds_read_window, {-kK1 * (k1_loops - 1), 0}); + } + + gemm_1(o_acc, + get_slice_tile(p_tile, + sequence<0, (k1_loops - 1) * kK1>{}, + sequence{}), + v_tile); + + block_sync_lds_direct_load(); + k_lds_read_window.set_bottom_tensor_view_data_ptr(k_lds_read_ptr); + k_tile = load_tile(k_lds_read_window); + + static_for<0, 12, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS_READ + }); + + static_for<0, 4, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ + }); + }; + + do + { + mainloop(i_total_loops); + i_total_loops++; + } while(i_total_loops < num_total_loop); + + if constexpr(kStoreLSE) + { + // store lse acc + auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto 
lse_acc_spans = decltype(lse_acc)::get_distributed_spans();
+            sweep_tile_span(lse_acc_spans[I0], [&, m_ = m, l_ = l](auto idx0) {
+                constexpr auto i_idx = make_tuple(idx0);
+                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                             BiasEnum == BlockAttentionBiasEnum::ALIBI)
+                {
+                    lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
+                }
+                else
+                {
+                    if constexpr(kHasLogitsSoftCap)
+                    {
+                        lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
+                    }
+                    else
+                    {
+                        lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
+                    }
+                }
+            });
+
+            if(get_thread_local_1d_id() < kM0)
+            {
+                store_tile(lse_acc_dram_window_tmp, lse_acc);
+            }
+        }
+
+        // finally, O
+        constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
+
+        sweep_tile_span(o_spans[I0], [&](auto idx0) {
+            constexpr auto i_idx = make_tuple(idx0);
+            const auto tmp = [&]() {
+                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                             FmhaMask::IsMasking)
+                {
+                    return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
+                }
+                else
+                    return 1 / l[i_idx];
+            }();
+            sweep_tile_span(o_spans[I1], [&](auto idx1) {
+                constexpr auto i_j_idx = make_tuple(idx0, idx1);
+                o_acc(i_j_idx) *= tmp;
+            });
+        });
+
+        return o_acc;
+    }
+};
+
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp
new file mode 100644
index 0000000000..ed22758566
--- /dev/null
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp
@@ -0,0 +1,823 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
+#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
+#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
+#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
+#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
+#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
+#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
+#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
+
+// Enabling this removes all LDS bank conflicts, but degrades performance in
+// some cases, probably because it limits compiler optimization.
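Note: the comment above refers to XOR swizzling of the LDS layout. Rather than storing a tile row-major as-is, the column coordinate is XORed with the row coordinate, so that lanes reading one column of the tile touch distinct LDS banks. Below is a minimal standalone sketch of that addressing trick, assuming a power-of-two row width; `swizzle_col` and `lds_offset` are hypothetical helpers for illustration, not ck_tile APIs (ck_tile expresses the same mapping with `make_xor_transform` inside the LDS block descriptors that follow). The `CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD` switch defined next appears to guard a folded variant of this mapping for tiles narrower than one 256-byte LDS layer.

```cpp
#include <cstdint>

// Hypothetical standalone helpers (not ck_tile APIs) showing the XOR swizzle
// idea. Assumes `cols` is a power of two. For a fixed row, col ^ (row % cols)
// is a bijection on [0, cols), so no two elements of a row collide; across
// rows, the same logical column lands in different banks, so a column-wise
// LDS read is conflict-free.
constexpr std::uint32_t swizzle_col(std::uint32_t row, std::uint32_t col, std::uint32_t cols)
{
    return col ^ (row % cols);
}

constexpr std::uint32_t lds_offset(std::uint32_t row, std::uint32_t col, std::uint32_t cols)
{
    return row * cols + swizzle_col(row, col, cols);
}

static_assert(lds_offset(0, 3, 8) == 3, "row 0 is stored unpermuted");
static_assert(lds_offset(1, 3, 8) == 10, "row 1 rotates column 3 to slot 2");
```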
+#define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0 +namespace ck_tile { +// This pipeline is qkv all located in LDS +struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy + : BlockFmhaPipelineQXKSVSCustomPolicy +{ + using BasePolicy = BlockFmhaPipelineQXKSVSCustomPolicy; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; + + constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType); + + // this should align with MakeQDramTileDistribution() + constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; + static_assert(0 < ElemPerThread); + return min(ElemPerThread, MaxVectorSize); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc() + { + using OaccDataType = remove_cvref_t; + + return static_cast(16 / sizeof(OaccDataType)); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; + + constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::KDataType); + + constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; + static_assert(0 < ElemPerThread); + return min(ElemPerThread, MaxVectorSize); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::VDataType); + + constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; + static_assert(0 < ElemPerThread); + return min(ElemPerThread, MaxVectorSize); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution() + { + if constexpr(!BypassLDS) + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; + + constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType); + + constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; + static_assert(0 < ElemPerThread); + constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); + + constexpr index_t KPerThread = kMaxVecLoad; + constexpr index_t KThreads = kKPerBlock / KPerThread; + constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + else + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = 
Problem::BlockFmhaShape::kSubQKHeaddim; + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto q_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<2, 1>, + sequence<0, 0>>{}; + + constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + q_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode); + + return q_block_dstr; + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() + { + using KDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = + LoadOnce ? Problem::BlockFmhaShape::kSubQKHeaddim : Problem::BlockFmhaShape::kK0; + + constexpr index_t MaxVectorSize = 16 / sizeof(KDataType); + constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; + + constexpr index_t K1 = min(MaxVectorSize, ElemPerThread); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + // Read M first, then K + // This is the same data consume order as BlockGEMM + constexpr auto q_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + q_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode); + + return q_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ() + { + // TODO: this is for 3d layout + using QDataType = remove_cvref_t; + return static_cast(16 / sizeof(QDataType)); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; + + constexpr index_t kKPack = GetSmemKPackQ(); + + constexpr auto q_lds_block_desc = [&]() { + if constexpr(Xor) + { +#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD + constexpr auto LDSLayerSize = 256 / sizeof(typename Problem::QDataType); + constexpr auto XorLengthFold = LDSLayerSize / kKPerBlock; + + if constexpr(XorLengthFold > 1) + { + constexpr auto 
q_lds_block_desc_naive = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto q_lds_block_desc_permuted = transform_tensor_descriptor( + q_lds_block_desc_naive, + make_tuple( + make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + constexpr auto q_lds_block_desc_tmp = transform_tensor_descriptor( + q_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(number{}), + make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + return transform_tensor_descriptor( + q_lds_block_desc_tmp, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple( + number{}, number{})), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else +#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD + { + constexpr auto q_lds_block_desc_naive = make_naive_tensor_descriptor( + make_tuple( + number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto q_lds_block_desc_permuted = transform_tensor_descriptor( + q_lds_block_desc_naive, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + return transform_tensor_descriptor( + q_lds_block_desc_permuted, + make_tuple(make_pass_through_transform(number{}), + make_merge_transform_v3_division_mod(make_tuple( + number{}, number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + else + { + return make_naive_tensor_descriptor( + make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{}), + number{}, + number<1>{}); + } + }(); + + return q_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = + LoadOnce ? 
Problem::BlockFmhaShape::kSubQKHeaddim : Problem::BlockFmhaShape::kK0; + + constexpr index_t kKPack = GetSmemKPackK(); + + constexpr auto k_lds_block_desc = [&]() { + if constexpr(Xor) + { +#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD + constexpr auto LDSLayerSize = 256 / sizeof(typename Problem::KDataType); + constexpr auto XorLengthFold = LDSLayerSize / kKPerBlock; + + if constexpr(XorLengthFold > 1) + { + constexpr auto k_lds_block_desc_naive = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto k_lds_block_desc_permuted = transform_tensor_descriptor( + k_lds_block_desc_naive, + make_tuple( + make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + constexpr auto k_lds_block_desc_tmp = transform_tensor_descriptor( + k_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(number{}), + make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + return transform_tensor_descriptor( + k_lds_block_desc_tmp, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple( + number{}, number{})), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else +#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD + { + constexpr auto k_lds_block_desc_naive = make_naive_tensor_descriptor( + make_tuple( + number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto k_lds_block_desc_permuted = transform_tensor_descriptor( + k_lds_block_desc_naive, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + return transform_tensor_descriptor( + k_lds_block_desc_permuted, + make_tuple(make_pass_through_transform(number{}), + make_merge_transform_v3_division_mod(make_tuple( + number{}, number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + else + { + return make_naive_tensor_descriptor( + make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{}), + number{}, + number<1>{}); + } + }(); + + return k_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t kKPack = GetSmemKPackV(); + + constexpr auto v_lds_block_desc = [&]() { + if constexpr(Xor) + { + constexpr auto XorGroupSize = + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}); + +#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD + constexpr auto LDSLayerSize = 256 / sizeof(typename Problem::VDataType); + constexpr auto XorLengthFold = LDSLayerSize / kNPerBlock; + + if constexpr(XorLengthFold > 1) + { + constexpr auto v_lds_block_desc_naive = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto v_lds_block_desc_permuted = 
transform_tensor_descriptor( + v_lds_block_desc_naive, + make_tuple( + make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + constexpr auto v_lds_block_desc_tmp = transform_tensor_descriptor( + v_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(number{}), + make_unmerge_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + return transform_tensor_descriptor( + v_lds_block_desc_tmp, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple( + number{}, number{})), + make_merge_transform_v3_division_mod(make_tuple( + number{}, number{}))), + make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else +#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD + { + constexpr auto v_lds_block_desc_naive = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto v_lds_block_desc_permuted = transform_tensor_descriptor( + v_lds_block_desc_naive, + make_tuple(make_xor_transform(make_tuple( + number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + return transform_tensor_descriptor( + v_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(number{}), + make_merge_transform_v3_division_mod(make_tuple( + number{}, number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + else + { + return make_naive_tensor_descriptor( + make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{}), + number{}, + number<1>{}); + } + }(); + + return v_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() + { + using GemmProblem = + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + typename Problem::BlockFmhaShape::Gemm0WarpTile>, + GemmLoopOrder::MNK>; + + using WarpGemm = + WarpGemmMfmaDispatcher{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}), + true>; + + using BlockGemmPolicy = + BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockGemmARegBRegCRegV1{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetPVBlockGemm() + { + using GemmProblem = + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm1BlockWarps, + typename Problem::BlockFmhaShape::Gemm1WarpTile>, + GemmLoopOrder::KMN>; + + using WarpGemm = WarpGemmMfmaDispatcher< + typename Problem::PDataType, + typename Problem::VDataType, + typename Problem::OaccDataType, + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), + true, + false, + false, + ((Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 16 && + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32) || + (Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 32 && + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 16)) + ? 
WGAttrNumAccessEnum::Double + : WGAttrNumAccessEnum::Single>; + + using BlockGemmPolicy = + BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockGemmARegBRegCRegV1{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKRegTileDistribution() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + // Read N first, then K + // This is the same data consume order as BlockGEMM + constexpr auto k_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode); + + return k_block_dstr; + } + + template + CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::VDataType); + + constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; + static_assert(0 < ElemPerThread); + constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); + + constexpr index_t NPerThread = kMaxVecLoad; + constexpr index_t NThreads = kNPerBlock / NPerThread; + constexpr index_t KThreadPerWarp = get_warp_size() / NThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t KPerThread = kKPerBlock / (KThreadPerWarp * NumWarps); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakePRegTileDistribution() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{}); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + // Read M first, then K + // This is the same data consume order as BlockGEMM + constexpr auto p_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<2, 1>, + sequence<0, 0>>{}; + + constexpr auto p_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + p_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + constexpr auto p_block_dstr = 
make_static_tile_distribution(p_block_dstr_encode); + + return p_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVRegTileDistribution() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{}); + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + // Read N first, then K + // This is the same data consume order as BlockGEMM + constexpr auto v_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<2, 1>, + sequence<0, 0>>{}; + + constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + constexpr auto v_block_dstr = + make_static_tile_distribution(typename InputTileDistributionTraits< + decltype(v_block_dstr_encode), + typename Problem::VDataType>::TransposedDstrEncode{}); + + return v_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemNPackS() + { + using SDataType = remove_cvref_t; + return static_cast(16 / sizeof(SDataType)); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeSLdsBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kNPack = GetSmemNPackS(); + + constexpr auto s_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kMPerBlock + 1) * kNPack>{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto s_lds_block_desc = transform_tensor_descriptor( + s_lds_block_desc_0, + make_tuple( + make_pass_through_transform(number{}), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return s_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeSRegTileDistribution() + { + using BlockGemm = remove_cvref_t())>; + + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + // static_assert(MWarp == 1, "Check failed!"); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kTileK = Problem::BlockFmhaShape::kN0; + + // K2 is equal to Impl::kABKPerLane * kKIterPerWarpGemm + constexpr index_t K3 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t K2 = WG::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t K1 = kKPerBlock / (K2 * K3); + constexpr index_t K0 = kTileK / kKPerBlock; + constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t M1 = MWarp; + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + constexpr auto s2_block_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<2, 2>>, + sequence<1, 2, 2, 2>, + 
sequence<0, 0, 1, 3>>{}; + + constexpr auto s2_block_dstr = make_static_tile_distribution(s2_block_dstr_encoding); + + return s2_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ() + { + return MakeQLdsBlockDescriptor().get_element_space_size() * + sizeof(typename Problem::QDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK() + { + return MakeKLdsBlockDescriptor().get_element_space_size() * + sizeof(typename Problem::KDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV() + { + return MakeVLdsBlockDescriptor().get_element_space_size() * + sizeof(typename Problem::VDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeS() + { + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); + + return NWarp > 1 ? MakeSLdsBlockDescriptor().get_element_space_size() * + sizeof(typename Problem::SaccDataType) + : 0; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + // Alignment on gfx950 is 1280 Bytes + // Alignment before gfx950 is 512 Bytes. + return max(GetSmemSizeQ(), + GetSmemSizeK() + GetSmemSizeS() + GetSmemSizeV()); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 3489d6f9a1..e2cea97f9a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -383,23 +383,31 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy; - return 16 / sizeof(VDataType); + using VDataType = remove_cvref_t; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + constexpr index_t kMaxVecLoad = + min(total_pixels, static_cast(16 / sizeof(VDataType))); + + return kMaxVecLoad; } template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() { - using VLayout = remove_cvref_t; - using VDataType = remove_cvref_t; + using VLayout = remove_cvref_t; + using VDataType = remove_cvref_t; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + constexpr index_t kMaxVecLoad = + min(total_pixels, static_cast(16 / sizeof(VDataType))); + if constexpr(std::is_same_v) { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; - constexpr index_t kMaxVecLoad = - min(total_pixels, static_cast(16 / sizeof(VDataType))); constexpr index_t kMinVecLoad = 4 / sizeof(VDataType); constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad) @@ -410,7 +418,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy; - using WarpGemm = typename Traits::WarpGemm; - using BlockGemmShape = typename Traits::BlockGemmShape; + using WarpGemm = typename Traits::WarpGemm; + 
using BlockGemmShape = typename Traits::BlockGemmShape; + static constexpr auto BlockGemmLoopOrder = Traits::BlockGemmLoopOrder; using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; @@ -86,17 +89,36 @@ struct BlockGemmARegBRegCRegV1 } else { - constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN) + { + constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<2, 1>, + sequence<0, 0>>{}; - return a_block_dstr_encode; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + return a_block_dstr_encode; + } + else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK) + { + constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + return a_block_dstr_encode; + } } } @@ -118,17 +140,33 @@ struct BlockGemmARegBRegCRegV1 } else { - constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN) + { + constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<2, 1>, + sequence<0, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - return b_block_dstr_encode; + return b_block_dstr_encode; + } + else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK) + { + constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + return b_block_dstr_encode; + } } } @@ -213,40 +251,82 @@ struct BlockGemmARegBRegCRegV1 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A Block window - AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN) + { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A Block window + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + 
merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B block tensor - BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - // read C warp tensor from C block tensor - using c_iter_idx = std:: - conditional_t, sequence>; - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + // read C warp tensor from C block tensor + using c_iter_idx = std::conditional_t, + sequence>; + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // warp GEMM + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); }); }); - }); + } + else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK) + { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + // read A warp tensor from A Block window + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } } CK_TILE_DEVICE static constexpr auto MakeCBlockTile() diff --git a/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp b/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp index 
fd5211a59a..d0be065fc9 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" namespace ck_tile { @@ -13,7 +14,8 @@ template + GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN, + index_t NumWaveGroups_ = 1> struct BlockGemmProblem { using ADataType = remove_cvref_t; @@ -21,8 +23,9 @@ struct BlockGemmProblem using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - static constexpr index_t kBlockSize = kBlockSize_; - static constexpr index_t NumWaveGroups = NumWaveGroups_; + static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t NumWaveGroups = NumWaveGroups_; + static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp index b18bf603a9..b3c86b9456 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp @@ -39,6 +39,12 @@ enum struct TailNumber Full, }; +enum struct GemmLoopOrder +{ + KMN, + MNK, +}; + } // namespace ck_tile inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineScheduler& s) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 52bd07c9e2..c628614b54 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -14,10 +14,11 @@ template + typename ComputeDataType_ = ADataType_, + bool FixedVectorSize_ = false, + index_t VectorSizeA_ = 1, + index_t VectorSizeB_ = 1, + GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN> struct GemmPipelineProblemBase { using Traits = remove_cvref_t; @@ -45,9 +46,10 @@ struct GemmPipelineProblemBase static constexpr bool kPadN = Traits::kPadN; static constexpr bool kPadK = Traits::kPadK; - static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer; - static constexpr auto Scheduler = GemmPipelineScheduler::Default; - static constexpr index_t VectorLoadSize = Traits::_VectorSize; + static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer; + static constexpr auto Scheduler = GemmPipelineScheduler::Default; + static constexpr index_t VectorLoadSize = Traits::_VectorSize; + static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_; // In the base situation, the Preshuffle setting should be false. 
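Aside on the `GemmLoopOrder` enum introduced above: it only selects how the block GEMM's `static_for` nest over warp-tile iterations is ordered; the arithmetic is unchanged. A scalar sketch of the two orders, with plain loops standing in for the MFMA warp tiles (names here are illustrative, not ck_tile types):

```cpp
#include <array>

// KMN: sweep all (m, n) accumulators for each k-slice of A and B.
template <int M, int N, int K>
void gemm_kmn(const std::array<std::array<float, K>, M>& a,
              const std::array<std::array<float, N>, K>& b,
              std::array<std::array<float, N>, M>& c)
{
    for(int k = 0; k < K; ++k)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                c[m][n] += a[m][k] * b[k][n];
}

// MNK: finish one C accumulator's full k-sweep before moving on.
template <int M, int N, int K>
void gemm_mnk(const std::array<std::array<float, K>, M>& a,
              const std::array<std::array<float, N>, K>& b,
              std::array<std::array<float, N>, M>& c)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                c[m][n] += a[m][k] * b[k][n];
}
```

Both produce identical sums; the choice mainly affects register lifetimes and how MFMA issue interleaves with LDS traffic, which is why the trload pipeline picks the order per GEMM.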
static constexpr bool Preshuffle = false; @@ -167,10 +169,11 @@ template + typename ComputeDataType_ = ADataType_, + bool FixedVectorSize_ = false, + index_t VectorSizeA_ = 1, + index_t VectorSizeB_ = 1, + GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN> using GemmPipelineProblem = GemmPipelineProblemBase; + VectorSizeB_, + BlockGemmLoopOrder_>; template + GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave, + bool HasHotLoop_ = true, + TailNumber TailNum_ = TailNumber::Full, + typename ComputeDataType_ = ADataType_, + bool FixedVectorSize_ = false, + index_t VectorSizeA_ = 1, + index_t VectorSizeB_ = 1, + GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN> struct UniversalGemmPipelineProblem { using Traits = remove_cvref_t; @@ -224,8 +229,9 @@ struct UniversalGemmPipelineProblem static constexpr auto Scheduler = Scheduler_; static constexpr bool Preshuffle = Traits::Preshuffle; - static constexpr index_t VectorSizeA = VectorSizeA_; - static constexpr index_t VectorSizeB = VectorSizeB_; + static constexpr index_t VectorSizeA = VectorSizeA_; + static constexpr index_t VectorSizeB = VectorSizeB_; + static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_; static constexpr auto HasHotLoop = HasHotLoop_; static constexpr auto TailNum = TailNum_; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index fb191d565d..d1deaf9e0e 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -104,6 +104,10 @@ using WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution = 1>>; #endif +using WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution = + WarpGemmImpl>>; + #if defined(__gfx950__) using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl>; #endif +using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution = + WarpGemmImpl>>; + #if defined(__gfx950__) using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution; }; // fp16 2:4 structural sparsity // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity @@ -74,6 +76,8 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution; }; // fp8 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity diff --git a/include/ck_tile/ops/reduce/block/block_reduce.hpp b/include/ck_tile/ops/reduce/block/block_reduce.hpp index 434be9f84a..7a10d1fa56 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce.hpp @@ -14,10 
+14,14 @@ namespace ck_tile {
  * Y dim must have at least one dim not been reduced
  */
 // synchronize reduce result (cross lane reduction and broadcast on replicated dimension)
-template
+template
 CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
                                            const ReduceFunc& reduce_func,
-                                           bool_constant = {})
+                                           bool_constant = {},
+                                           bool_constant = {})
 {
     using Dstr       = typename AccDistributedTensor_::StaticTileDistribution;
     using DstrEncode = typename Dstr::DstrEncode;
@@ -56,14 +60,24 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
                 // reduction sweep forward
                 static_for<0, nstage, 1>{}([&](auto istage) {
-                    constexpr index_t lid_delta =
-                        lid_over_rid_derivative * (1 << (nstage - istage - 1));
+                    if constexpr(CrossWarp)
+                    {
+                        constexpr index_t lid_delta =
+                            lid_over_rid_derivative * (1 << (nstage - istage - 1));
 
-                    // pull data from remote lane
-                    const auto v_remote = warp_shuffle_down(v_local, lid_delta);
+                        // pull data from remote lane
+                        const auto v_remote = warp_shuffle_down(v_local, lid_delta);
 
-                    // reduce
-                    v_local = reduce_func(v_local, v_remote);
+                        // reduce
+                        v_local = reduce_func(v_local, v_remote);
+                    }
+                    else
+                    {
+                        // pull data from remote lane
+                        const auto v_swapped_regs = warp_shuffle_down_pair(v_local);
+                        // reduce
+                        v_local = reduce_func(v_swapped_regs.at(0), v_swapped_regs.at(1));
+                    }
                 });
             }
         });
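In the hunk above, the new `CrossWarp == false` branch swaps the shifted shuffle for `warp_shuffle_down_pair`, reducing a register pair in place; the default `CrossWarp` path remains the classic log2 shuffle-down tree. A host-side model of that tree, with a vector standing in for one wavefront's lanes (a sketch only, assuming a power-of-two lane count; the real `warp_shuffle_down` is a device crosslane intrinsic wrapper, and `nstage` / `lid_over_rid_derivative` additionally account for the tile distribution):

```cpp
#include <cstddef>
#include <functional>
#include <vector>

// Host model of the CrossWarp shuffle-down reduction sweep. Each stage halves
// the shuffle distance, mirroring lid_over_rid_derivative * (1 << (nstage -
// istage - 1)). After the sweep, lane 0 holds the reduction over all lanes.
float shuffle_down_reduce(std::vector<float> lanes,
                          const std::function<float(float, float)>& reduce_func)
{
    for(std::size_t delta = lanes.size() / 2; delta > 0; delta /= 2)
    {
        for(std::size_t lane = 0; lane + delta < lanes.size(); ++lane)
        {
            // "pull data from remote lane", then reduce, as in the hunk
            lanes[lane] = reduce_func(lanes[lane], lanes[lane + delta]);
        }
    }
    return lanes[0];
}
```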
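Returning to the trload pipeline earlier in this patch: the mainloop bookkeeping around `m`, `l`, and the `tmp` rescale of `o_acc`, plus the final `1 / l` epilogue, is the standard online-softmax recurrence. A self-contained scalar version for one attention row (illustrative only; the kernel additionally folds in `scale_s`, base-2 exponentials via `exp2`, bias/masking variants, and optional LSE output):

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Streaming softmax-weighted average, mirroring the per-row (m, l, o_acc)
// state in the pipeline: m is the running max, l the running denominator,
// o the running numerator.
float online_softmax_dot(const std::vector<float>& s, const std::vector<float>& v)
{
    assert(s.size() == v.size());
    float m = -INFINITY, l = 0.f, o = 0.f;
    for(std::size_t j = 0; j < s.size(); ++j)
    {
        // A masked-out position (score == -inf) contributes nothing; skipping
        // it plays the role of the kernel's get_validated_m guard against a
        // fully masked row producing NaN.
        if(std::isinf(s[j]) && s[j] < 0.f)
            continue;
        const float m_new = std::fmax(m, s[j]);
        const float scale = std::exp(m - m_new); // rescale old partial sums
        const float p     = std::exp(s[j] - m_new);
        l = scale * l + p;        // l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]
        o = scale * o + p * v[j]; // o_acc *= tmp, then the PV gemm adds p * v
        m = m_new;
    }
    return l == 0.f ? 0.f : o / l; // epilogue: o_acc *= (l == 0 ? 0 : 1 / l)
}
```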
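One more note on the mainloop structure: it overlaps global-to-LDS traffic with math by ping-ponging two K buffers and two V buffers, with `is_even_loop` selecting which half of `smem_ptrk0`/`smem_ptrk1` and `smem_ptrv0`/`smem_ptrv1` is read while the async loads fill the other. A shape-only sketch of that schedule (all functions here are hypothetical stand-ins, not ck_tile APIs; the kernel drives this via `async_load_tile` plus the `set_bottom_tensor_view_data_ptr` swaps):

```cpp
#include <cstddef>
#include <cstdio>

// Hypothetical stand-ins for an async global->LDS copy, the barrier that
// makes it visible, and the per-iteration math.
static void async_copy_block(int* dst, std::size_t i) { *dst = static_cast<int>(i); }
static void wait_async() { /* plays the role of block_sync_lds_direct_load() */ }
static void compute_block(const int* src, std::size_t i)
{
    std::printf("iteration %zu consumes tile %d\n", i, *src);
}

// Shape of the ping-pong schedule: iteration i computes out of the "read"
// half while the copy for iteration i+1 fills the "write" half.
void pingpong_mainloop(int* buf0, int* buf1, std::size_t num_loops)
{
    if(num_loops == 0)
        return;
    async_copy_block(buf0, 0); // prologue prefetch, as before the do/while
    for(std::size_t i = 0; i < num_loops; ++i)
    {
        int* read  = (i % 2 == 0) ? buf0 : buf1;
        int* write = (i % 2 == 0) ? buf1 : buf0;
        if(i + 1 < num_loops)
            async_copy_block(write, i + 1); // overlap the next load with math
        wait_async();
        compute_block(read, i);
    }
}
```

The `__builtin_amdgcn_sched_group_barrier` runs in the kernel serve a related purpose at a finer grain: they pin the compiler's interleaving of MFMA and DS_READ instructions so the latency hiding set up by this buffer scheme survives scheduling.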