From 51886bf22b3b36d668df2fbeab2da2642105a529 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Mon, 15 Dec 2025 12:21:59 +0800 Subject: [PATCH] Add attention sink support for FMHA FWD (#3368) * Revert "Revert "Add attn sink (#2892)" (#3250)" This reverts commit e3be392d13e6ee107d823af32aca2d3ff03ca69d. * fix conflict Signed-off-by: Linjun-AMD * Add F_sink parameter to FmhaFwdPipeline * Update tile_fmha_traits.hpp * Refactor pipeline creation in fmha_fwd.py Updated the pipeline creation logic to include 'sink' parameter in product combinations and adjusted the FmhaFwdPipeline calls accordingly. * Update fmha_fwd.py * Update fmha_fwd.py * Update example/ck_tile/01_fmha/script/correct_test_fwd_sink.sh Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * update CHANGELOG.md Signed-off-by: Linjun-AMD * Update CHANGELOG with new features and support * Update fmha_fwd.hpp * Update CHANGELOG.md * Update smoke_test_fwd_sink.sh * Update correct_test_fwd_sink.sh * Update smoke_test_fwd_sink.sh --------- Signed-off-by: Linjun-AMD Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> [ROCm/composable_kernel commit: f5573f56d9d4981def16f575ddb14535b93bb9bb] --- CHANGELOG.md | 1 + example/ck_tile/01_fmha/CMakeLists.txt | 2 +- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 75 +++++--- .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 42 +++-- .../codegen/ops/fmha_pagedkv_prefill.py | 33 +++- example/ck_tile/01_fmha/fmha_fwd.hpp | 25 ++- example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 6 +- example/ck_tile/01_fmha/mask.hpp | 42 ++++- .../01_fmha/script/correct_test_fwd_sink.sh | 77 ++++++++ .../ck_tile/01_fmha/script/run_full_test.sh | 1 + .../01_fmha/script/smoke_test_fwd_sink.sh | 86 +++++++++ .../reference/reference_batched_masking.hpp | 2 +- .../ck_tile/ops/fmha/block/block_masking.hpp | 178 ++++++++++++++++-- include/ck_tile/ops/fmha/block/variants.hpp | 33 ++++ .../fmha/kernel/fmha_batch_prefill_kernel.hpp | 7 +- .../ops/fmha/kernel/fmha_fwd_kernel.hpp 
| 17 +- .../fmha/kernel/fmha_fwd_pagedkv_kernel.hpp | 14 +- .../fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 10 +- ...ock_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp | 94 ++++++--- ...litkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp | 86 ++++++--- ...ock_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp | 92 +++++++-- .../pipeline/block_fmha_pipeline_problem.hpp | 3 + .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 99 ++++++++-- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 102 +++++++--- ...ck_fmha_pipeline_qr_ks_vs_async_trload.hpp | 1 + .../ops/fmha/pipeline/tile_fmha_traits.hpp | 16 +- 26 files changed, 948 insertions(+), 196 deletions(-) create mode 100644 example/ck_tile/01_fmha/script/correct_test_fwd_sink.sh create mode 100755 example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh diff --git a/CHANGELOG.md b/CHANGELOG.md index 15fdb09f49..997fb8bb8c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ### Added * Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight. * Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32". +* Added attention sink support for FMHA FWD, including qr_ks_vs, qr_async and splitkv pipelines. ### Changed diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 6e7d69281d..9c81207361 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -65,7 +65,7 @@ set(FMHA_BWD_CODE_GEN_COMMON_ARGS # there is no corresponding instance for parameters). 
if(BUILD_TESTING) # Filters are in the order of FMHA_FWD_KNOWN_APIS: fwd,fwd_splitkv_combine@fwd_splitkv,fwd_appendkv,pagedkv_prefill - list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*,*@*_nlogits*_nbias*,*,*_nlogits*_nskip*_pagedkv) + list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*_nsink*,*@*_nlogits*_nbias*_nsink*,*,*_nlogits*_nskip*_pagedkv*) endif() # generate a list of kernels, but not actually emit files at config sta diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index edc0e049c5..4d6900a802 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -76,7 +76,8 @@ using fmha_traits = ck_tile::TileFmhaTraits<{F_spad}, {F_dropout}, {F_qscale}, {F_occupancy}, - {F_skip}>; + {F_skip}, + {F_sink}>; using fmha_variant = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; @@ -113,7 +114,7 @@ using fmha_kernel = {F_kernel}; using trait = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; + {F_pipeline_enum}, {F_logits}, fmha_mask, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink}>; template<> float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) @@ -229,9 +230,9 @@ FMHA_FWD_API_PER_HDIM_CASE = """{F_if}(t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hd }} """ -FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q 
== {F_skip}) && +FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) &&(t.has_sink == {F_sink}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink}>; return fmha_fwd_(s, a); }} """ @@ -278,13 +279,14 @@ class FmhaFwdApiTrait: dvpad: str skip: str tr_load: str + sink: str constraint: CppConstraint @property def name(self) -> str: return ( f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" - + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}-{self.sink}" ) @property @@ -384,6 +386,7 @@ class FmhaFwdPipeline: F_mask: str # value from MASK_MAP F_skip: str # true/false F_trload: str # true/false + F_sink: str # true/false F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint()) @property @@ -454,6 +457,10 @@ class 
FmhaFwdPipeline: n += "_trload" else: n += "_ntrload" + if self.F_sink == "t": + n += "_sink" + else: + n += "_nsink" return n @@ -543,6 +550,7 @@ class FmhaFwdApiPool: F_trload=BOOL_MAP[trait.tr_load], F_qscale_check=QSCALE_CHECK_MAP[trait.qscale], F_qscale=QSCALE_MAP[trait.qscale], + F_sink=BOOL_MAP[trait.sink], F_scheck=trait.scheck, F_seqtune=trait.seqtune(max_bm0), F_skcheck=trait.skcheck, @@ -683,6 +691,7 @@ class FmhaFwdKernel: F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], F_kernel=self._get_cpp_kernel_class_name(self.F_pipeline.tag), F_kargs_creator=self._get_cpp_kargs_creator_func_name(self.F_pipeline.tag), + F_sink=BOOL_MAP[self.F_pipeline.F_sink], ) @property @@ -725,6 +734,7 @@ class FmhaFwdKernel: dvpad=self.F_pipeline.F_dvpad, skip=self.F_pipeline.F_skip, tr_load=self.F_pipeline.F_trload, + sink=self.F_pipeline.F_sink, constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, ) @@ -957,52 +967,55 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9): pipelines = [] if dtype in cls._DT_FP32: qscale = "no" - for logits, mask, bias, lse, dropout, skip in itertools.product( + for logits, mask, bias, lse, dropout, skip, sink in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], + ["t", "f"], ): - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + 
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip elif dtype in cls._DT_FP16_BF16: qscale = "no" - for logits, mask, bias, lse, dropout, skip in itertools.product( + for logits, mask, bias, lse, dropout, skip, sink in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], + ["t", "f"], ): if hdim == 256 and hdim_v == 256: - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip # the below two is used for hdim vectorize load - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip else: if bias == "bias": # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip else: - 
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip if receipt == 1 and bias != "bias": - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip elif dtype in cls._DT_FP8BF16 or dtype in cls._DT_FP8FP32: # no need lse/dropout kernels - for logits, qscale, mask, bias in itertools.product( + for logits, qscale, mask, bias, sink in itertools.product( ["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"], + ["f", "t"], ): - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip elif dtype in ["fp8", "fp8fp16", "bf8"]: # TODO pass @@ -1033,13 +1046,14 @@ class KernelComponentFactoryGfx950( ) if dtype in cls._DT_FP16_BF16: qscale = "no" - for logits, mask, bias, lse, 
dropout, skip in itertools.product( + for logits, mask, bias, lse, dropout, skip, sink in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], + ["t", "f"], ): if ( (hdim, hdim_v) in [(64, 64), (128, 128)] @@ -1048,15 +1062,15 @@ class KernelComponentFactoryGfx950( and dropout == "f" and skip == "f" ): - pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip # qr_async_trload_v3 only supports hdim=hdim_v=128 for now if (hdim, hdim_v) == (128, 128): # qr_async_trload_v3 only supports (generic) causal mask for mask in ["no", "causal"]: pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", - F_logits="f", F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t")) # fmt: skip + F_logits="f", F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip return pipelines @@ -1105,23 +1119,24 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory): pipelines = [] if dtype in cls._DT_FP16_BF16: qscale = "no" - for logits, mask, bias, lse, dropout, skip in itertools.product( + for logits, mask, bias, lse, dropout, skip, sink in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], + ["t", "f"], ): - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, 
qscale, mask, skip, "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32: # no need lse/dropout kernels for logits, qscale, mask, bias in itertools.product( ["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"] ): - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip return pipelines diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 342a71e0b0..9105900fc7 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -73,7 +73,8 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad}, {F_pagedkv}, kHasUnevenSplits, kMergeNumHeadGroupsSeqLenQ, - {F_occupancy}>; + {F_occupancy}, + {F_sink}>; using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< typename FmhaFwdTypeConfig::QDataType, @@ -117,7 +118,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) }} // anonymous namespace using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, 
{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; #pragma clang diagnostic push @@ -279,8 +280,8 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const """ FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && - ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + ((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; // get combine kernel tile sizes using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType; @@ -332,6 +333,7 @@ class FmhaFwdSplitKVApiTrait: dpad: str dvpad: str pagedkv: str + sink: str # sink or not bn1comb: int # tile size along v head_dim of combine kernel @property @@ -339,7 +341,7 @@ class FmhaFwdSplitKVApiTrait: return ( 
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-" - + f"{self.dvpad}-{self.pagedkv}" + + f"{self.dvpad}-{self.pagedkv}-{self.sink}" ) @property @@ -425,6 +427,7 @@ class FmhaFwdSplitKVPipeline: F_lse: str # F_squant: str # F_pagedkv: str # t/f + F_sink: str # t/f F_mask: str # value from MASK_MAP @property @@ -485,6 +488,10 @@ class FmhaFwdSplitKVPipeline: n += "_pagedkv" else: n += "_npagedkv" + if self.F_sink == "t": + n += "_sink" + else: + n += "_nsink" return n @@ -567,6 +574,7 @@ class FmhaFwdSplitKVApiPool: F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv], + F_sink=BOOL_MAP[trait.sink], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, @@ -667,6 +675,7 @@ class FmhaFwdSplitKVKernel: F_squant=BOOL_MAP[self.F_pipeline.F_squant], F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv], F_occupancy=self.F_tile.F_occupancy, + F_sink=BOOL_MAP[self.F_pipeline.F_sink], F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], F_mode=MODE_MAP[self.F_mode], @@ -740,19 +749,23 @@ class KernelComponentFactoryBase: squant = "t" if dtype == "fp8" else "f" pipelines = [] if dtype in ["fp16", "bf16"]: - for logits, mask, bias, pagedkv in itertools.product( - ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"] + for logits, mask, bias, pagedkv, sink in itertools.product( + ["t", "f"], + get_mask_map(mask_impl).keys(), + BIAS_MAP.keys(), + ["t", "f"], + ["t", "f"], ): - pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip - pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip - pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, 
bias, "t", squant, pagedkv, mask)) # fmt: skip - pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip elif dtype in ["fp8", "bf8"]: for logits, mask, bias in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() ): - pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip - pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip elif dtype in ["fp8fp16", "fp8bf16"]: # TODO None @@ -908,6 +921,7 @@ def get_fwd_splitkv_blobs( cond &= pipeline.F_vlayout == "row" cond &= pipeline.F_bias in ["no", "alibi"] cond &= pipeline.F_squant == "f" + cond &= pipeline.F_sink == "f" if not cond: continue # PyTorch integration @@ -917,6 +931,7 @@ def get_fwd_splitkv_blobs( cond &= pipeline.F_bias in ["no", "bias"] cond &= pipeline.F_squant == "f" cond &= mode == "batch" + cond &= pipeline.F_sink == "f" if not cond: continue # Aiter(mha_varlen_fwd) integration @@ -1075,6 +1090,7 @@ def write_blobs( lse=kernel.F_pipeline.F_lse, squant=kernel.F_pipeline.F_squant, pagedkv=kernel.F_pipeline.F_pagedkv, + sink=kernel.F_pipeline.F_sink, spad=kernel.F_pipeline.F_spad, 
skpad=kernel.F_pipeline.F_skpad, dpad=kernel.F_pipeline.F_dpad, diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py index e6eb893a2f..cdb43c3480 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -65,7 +65,8 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdPagedKVTraits<{F_spad}, {F_pagedkv}, //pagedkv {F_squant}, {F_occupancy}, - {F_skip}>; + {F_skip}, + {F_sink}>; using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; @@ -100,7 +101,7 @@ using fmha_kernel_{F_idx} = ck_tile::FmhaFwdPagedKVKernel; using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}, {F_sink}>; template<> float fmha_fwd_pagedkv_(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a) @@ -129,9 +130,9 @@ float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, con }} """ -FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && +FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == 
{F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; + using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip},{F_sink}>; return fmha_fwd_pagedkv_(s, a); }} """ @@ -163,12 +164,13 @@ class FmhaFwdApiTrait: dpad: str dvpad: str skip: str + sink: str @property def name(self) -> str: return ( f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" - + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}-{self.sink}" ) @property @@ -256,6 +258,7 @@ class FmhaFwdPipeline: F_squant: str # F_mask: str # value from MASK_MAP F_skip: str # true/false + F_sink: str # true/false @property def name(self) -> str: @@ -320,6 +323,10 @@ class FmhaFwdPipeline: n += "_pagedkv" else: n += "_npagedkv" + if self.F_sink == "t": + n += "_sink" + else: + n += "_nsink" return n @@ -363,6 +370,7 @@ class FmhaFwdApiPool: F_lse=BOOL_MAP[trait.lse], F_pagedkv=BOOL_MAP[trait.pagedkv], F_skip=BOOL_MAP[trait.skip], + F_sink=BOOL_MAP[trait.sink], F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, @@ -480,6 +488,7 @@ class 
FmhaFwdKernel: F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv], F_squant=BOOL_MAP[self.F_pipeline.F_squant], F_skip=BOOL_MAP[self.F_pipeline.F_skip], + F_sink=BOOL_MAP[self.F_pipeline.F_sink], F_occupancy=self.F_tile.F_occupancy, F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], @@ -526,6 +535,7 @@ class FmhaFwdKernel: dpad=self.F_pipeline.F_dpad, dvpad=self.F_pipeline.F_dvpad, skip=self.F_pipeline.F_skip, + sink=self.F_pipeline.F_sink, ) @@ -539,22 +549,23 @@ class KernelComponentFactoryBase: squant = "t" if dtype == "fp8" else "f" pipelines = [] if dtype in ["fp16", "bf16"]: - for logits, mask, bias, pagedkv, skip in itertools.product( + for logits, mask, bias, pagedkv, skip, sink in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t"], ["f"], + ["t", "f"], ): - pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip elif dtype in ["fp8", "bf8"]: # no need lse/dropout kernels for logits, mask, bias in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() ): - pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip 
+ pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip elif dtype in ["fp8fp16", "fp8bf16"]: pass # TODO else: @@ -678,6 +689,7 @@ def get_fwd_blobs( cond &= pipeline.F_bias in ["no", "alibi"] cond &= pipeline.F_squant == "f" cond &= pipeline.F_skip == "f" + cond &= pipeline.F_sink == "f" if not cond: continue # PyTorch integration @@ -687,6 +699,7 @@ def get_fwd_blobs( cond &= pipeline.F_bias in ["no", "bias"] cond &= pipeline.F_squant == "f" cond &= pipeline.F_skip == "f" + cond &= pipeline.F_sink == "f" if not cond: continue # Aiter(mha_fwd) integration diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 002d0a1035..60ba334fc0 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -266,6 +266,7 @@ struct fmha_fwd_args ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; + ck_tile::index_t sink_size; ck_tile::index_t mask_type; ck_tile::index_t min_seqlen_q; @@ -352,6 +353,7 @@ struct fmha_fwd_pagedkv_args ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; + ck_tile::index_t sink_size; ck_tile::index_t mask_type; ck_tile::index_t min_seqlen_q; }; @@ -442,6 +444,7 @@ struct fmha_fwd_splitkv_args ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; + ck_tile::index_t sink_size; ck_tile::index_t mask_type; }; @@ -561,6 +564,7 @@ struct fmha_batch_prefill_args ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; + ck_tile::index_t sink_size; ck_tile::index_t mask_type; float p_drop; @@ -613,6 +617,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.nhead_stride_o, args.window_size_left, args.window_size_right, + args.sink_size, args.mask_type, args.min_seqlen_q, args.p_drop, @@ -663,6 +668,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.batch_stride_o, args.window_size_left, args.window_size_right, 
+ args.sink_size, args.mask_type, args.p_drop, args.s_randval, @@ -824,6 +830,7 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args) args.batch_stride_v, args.window_size_left, args.window_size_right, + args.sink_size, args.mask_type, args.min_seqlen_q); } @@ -869,6 +876,7 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args) args.batch_stride_o, args.window_size_left, args.window_size_right, + args.sink_size, args.mask_type); } }(); @@ -935,6 +943,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.split_stride_o_acc, args.window_size_left, args.window_size_right, + args.sink_size, args.mask_type); } else @@ -982,6 +991,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.split_stride_o_acc, args.window_size_left, args.window_size_right, + args.sink_size, args.mask_type); } }(); @@ -1142,6 +1152,7 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.batch_stride_v, args.window_size_left, args.window_size_right, + args.sink_size, args.mask_type, args.p_drop, args.s_randval, @@ -1194,6 +1205,7 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.batch_stride_o, args.window_size_left, args.window_size_right, + args.sink_size, args.mask_type, args.p_drop, args.s_randval, @@ -1228,7 +1240,8 @@ template + bool kSkipMinSeqlenQ_ = false, + bool kHasSink_ = false> struct fmha_fwd_traits_ { static constexpr ck_tile::index_t HDim = HDim_; @@ -1254,6 +1267,7 @@ struct fmha_fwd_traits_ static constexpr bool kPadDv = kPadDv_; static constexpr bool kUseTrLoad = kUseTrLoad_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; + static constexpr bool kHasSink = kHasSink_; }; template @@ -1280,7 +1294,8 @@ template + bool kSkipMinSeqlenQ_ = false, + bool kHasSink_ = false> struct fmha_fwd_pagedkv_traits_ { static constexpr ck_tile::index_t HDim = HDim_; @@ -1305,6 +1320,7 @@ struct fmha_fwd_pagedkv_traits_ 
static constexpr bool kPadD = kPadD_; static constexpr bool kPadDv = kPadDv_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; + static constexpr bool kHasSink = kHasSink_; }; template @@ -1327,6 +1343,7 @@ template @@ -1440,6 +1458,7 @@ struct fmha_fwd_traits bool has_dropout; quant_scale_enum qscale_type; bool skip_min_seqlen_q = false; + bool has_sink = false; // TODO: padding check is inside this api }; float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); @@ -1458,6 +1477,7 @@ struct fmha_fwd_pagedkv_traits bool use_pagedkv = true; bool do_fp8_static_quant = false; bool skip_min_seqlen_q = false; + bool has_sink = false; // TODO: padding check is inside this api }; @@ -1477,6 +1497,7 @@ struct fmha_fwd_splitkv_traits bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bool has_lse; bool do_fp8_static_quant; + bool has_sink = false; // TODO: padding check is inside this api }; float fmha_fwd_splitkv(fmha_fwd_splitkv_traits, diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index bca4c60bc6..536fcb0692 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -879,6 +879,7 @@ fwd_result fmha_fwd_run(mode_enum mode, traits.has_logits_soft_cap = 0.f < logits_soft_cap; traits.mask_type = mask.type; traits.bias_type = bias.type; + traits.has_sink = mask.sink > 0 ? 
true : false; traits.has_lse = lse; if constexpr(std::is_same_v>) @@ -1042,6 +1043,7 @@ fwd_result fmha_fwd_run(mode_enum mode, args.window_size_left = mask.left; args.window_size_right = mask.right; + args.sink_size = mask.sink; args.mask_type = static_cast(mask.type); if constexpr(std::is_same_v>) @@ -1645,7 +1647,7 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::reference_batched_masking( s_host_ref, ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, mask.right, real_seqlen_q, real_seqlen_k)); + mask.left, mask.right, mask.sink, real_seqlen_q, real_seqlen_k)); } else { @@ -1657,6 +1659,7 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::make_generic_attention_mask_from_lr_window( mask.left, mask.right, + mask.sink, real_seqlen_q, real_seqlen_k, mask.type == mask_enum::mask_top_left)); @@ -1666,6 +1669,7 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::make_generic_attention_mask_from_lr_window( mask.left, mask.right, + mask.sink, real_seqlen_q, real_seqlen_k, mask.type == mask_enum::mask_top_left)); diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp index 32157a2245..f85b811116 100644 --- a/example/ck_tile/01_fmha/mask.hpp +++ b/example/ck_tile/01_fmha/mask.hpp @@ -25,6 +25,7 @@ struct mask_info ck_tile::index_t seqlen_k; ck_tile::index_t y, x; ck_tile::index_t left, right; // FA style SWA left/right + ck_tile::index_t sink; void serialize(std::ostream& os) const { @@ -58,13 +59,14 @@ struct mask_info ck_tile::index_t window_size = std::stoi(v); ck_tile::index_t left_size = -1; ck_tile::index_t right_size = 0; + ck_tile::index_t sink_size = 0; if(window_size > 0) { left_size = window_size / 2; right_size = window_size - 1 - left_size; } auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( - left_size, right_size, y_total, x_total, t == "xt"); + left_size, right_size, sink_size, y_total, x_total, t == "xt"); tmp.type = t == "xt" ? 
mask_enum::mask_top_left : mask_enum::mask_bottom_right; tmp.y = r.at(ck_tile::number<0>{}); @@ -79,27 +81,54 @@ struct mask_info { throw std::invalid_argument("invalid mask value: " + str); } - ck_tile::index_t v0 = std::stoi(v.substr(0, found_1)); - ck_tile::index_t v1 = std::stoi(v.substr(found_1 + 1)); + tmp.type = mask_enum::window_generic; + ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str()); + auto found_2 = v.find(',', found_1 + 1); + // Parse v1 once for every mask kind ("t", "b" and "g" all read it below); atoi stops at an optional ",sink" suffix. + ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str()); + ck_tile::index_t sink = 0; + // TODO: some validation if(t == "t") { + if(found_2 != std::string::npos) + { + v1 = atoi(v.substr(found_1 + 1, found_2 - found_1 - 1).c_str()); + sink = atoi(v.substr(found_2 + 1).c_str()); + } + else + { + v1 = atoi(v.substr(found_1 + 1).c_str()); + sink = 0; + } tmp.type = mask_enum::mask_top_left; auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( - v0, v1, y_total, x_total, true); + v0, v1, sink, y_total, x_total, true); tmp.y = r.at(ck_tile::number<0>{}); tmp.x = r.at(ck_tile::number<1>{}); tmp.left = v0; tmp.right = v1; + tmp.sink = sink; } else if(t == "b") { + if(found_2 != std::string::npos) + { + v1 = atoi(v.substr(found_1 + 1, found_2 - found_1 - 1).c_str()); + sink = atoi(v.substr(found_2 + 1).c_str()); + } + else + { + v1 = atoi(v.substr(found_1 + 1).c_str()); + sink = 0; + } tmp.type = mask_enum::mask_bottom_right; auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( - v0, v1, y_total, x_total, false); + v0, v1, sink, y_total, x_total, false); tmp.y = r.at(ck_tile::number<0>{}); tmp.x = r.at(ck_tile::number<1>{}); tmp.left = v0; tmp.right = v1; + tmp.sink = sink; } else if(t == "g") { @@ -108,6 +137,7 @@ struct mask_info tmp.x = v1; tmp.left = v0; // TODO: don't use this? 
tmp.right = v1; + tmp.sink = 0; } } else @@ -126,6 +156,7 @@ struct mask_info tmp.x = 1; tmp.left = -1; tmp.right = 0; + tmp.sink = 0; } else if(str == "2" || str == "b") { @@ -134,6 +165,7 @@ struct mask_info tmp.x = seqlen_k - seqlen_q + 1; tmp.left = -1; tmp.right = 0; + tmp.sink = 0; } else { diff --git a/example/ck_tile/01_fmha/script/correct_test_fwd_sink.sh b/example/ck_tile/01_fmha/script/correct_test_fwd_sink.sh new file mode 100644 index 0000000000..de3bff25ed --- /dev/null +++ b/example/ck_tile/01_fmha/script/correct_test_fwd_sink.sh @@ -0,0 +1,77 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# TODO: run this script from CK root or build directory +EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)" +KNAME=1 + +export CK_WARMUP=0 +export CK_REPEAT=1 + +COMMON_ARGS='-v=1 -warmup=0 -repeat=1' +# mode=0 +# export HIP_VISIBLE_DEVICES=4 + +TEST_SPLITKV=0 +TEST_APPENDKV=0 +# options: +# -s: run splitkv tests +# -a: run appendkv tests +while getopts ":sa" opt; do + case "${opt}" in + s) + TEST_SPLITKV=1 + ;; + a) + TEST_APPENDKV=1 + ;; + *) + ;; + esac +done + +run_fp16_bf16_tests() { + local NUM_SPLITS="1" + local PAGE_BLOCK_SIZE="0" + local CACHE_BATCH_IDX="0" + + if [ $TEST_SPLITKV -eq 1 ] ; then + NUM_SPLITS="$NUM_SPLITS 2 3" + PAGE_BLOCK_SIZE="$PAGE_BLOCK_SIZE 128" + CACHE_BATCH_IDX="$CACHE_BATCH_IDX 1" + fi + + for prec in "fp16"; do + for mode in 1 0 ; do + for perm in 0 1 ; do + for vlayout in "r" "c" ; do + for batch in 1 4; do + for head in 1; do + for h_k in 1; do + for q_seq in 128 512 ; do + for kv_seq in 128 1024; do + for hdim in 32 64 128 256; do #256 + for lse in 0 1 ; do + for bias in "e" ; do + for p_drop in 0.0 0.2; do # 0.0 + for mask in "t:2,0,4" "b:1,0,2"; do + for num_splits in $NUM_SPLITS ; do + for page_block_size in $PAGE_BLOCK_SIZE ; do + for cache_batch_idx in $CACHE_BATCH_IDX ; do + + # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 
-bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=$batch -h=$head -h_k=$h_k -d=16 -d_v=$hdim -s=$q_seq -s_k=$kv_seq -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS -mask=$mask + + done ; done ; done ; done ; done + done ; done ; done ; done ; done + done ; done ; done ; done ; done + done ; done +} + + +set -x + +run_fp16_bf16_tests + +set +x diff --git a/example/ck_tile/01_fmha/script/run_full_test.sh b/example/ck_tile/01_fmha/script/run_full_test.sh index 4fbde37cae..456c3986fa 100755 --- a/example/ck_tile/01_fmha/script/run_full_test.sh +++ b/example/ck_tile/01_fmha/script/run_full_test.sh @@ -39,6 +39,7 @@ function print_log_header(){ #run verification tests time example/ck_tile/01_fmha/script/smoke_test_fwd.sh time example/ck_tile/01_fmha/script/smoke_test_bwd.sh +time example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh #run performance benchmarks export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log" diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh new file mode 100755 index 0000000000..664c825418 --- /dev/null +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh @@ -0,0 +1,86 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# TODO: run this script from CK root or build directory +#EXE="/code/composable_kernel/build/bin/tile_example_fmha_fwd" +set -euo pipefail + +SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd) +EXE_NAME=tile_example_fmha_fwd +EXE="$(find . 
-name $EXE_NAME -type f | head -n 1)" +KNAME=1 +GPU_arch=${GPU_arch:-} # default to empty: a bare $GPU_arch aborts under 'set -u' when the variable is unset +if [ -z "$GPU_arch" ] ; then + GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}') +fi +set -x + +COMMON_ARGS='-v=1 -warmup=0 -repeat=1' + + +$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:2,0,2 + +# window_size[2,0], sink_size = 2 + +# x=1/y=3 +# 1 * * * * * * * 1 * * * * * * * +# 1 1 * * * * * * 1 1 * * * * * * +# 1 1 1 * * * * * ----> 1 1 1 * * * * * +# * 1 1 1 * * * * 1 1 1 1 * * * * +# * * 1 1 1 * * * 1 1 1 1 1 * * * +# * * * 1 1 1 * * 1 1 * 1 1 1 * * +# * * * * 1 1 1 * 1 1 * * 1 1 1 * +# * * * * * 1 1 1 1 1 * * * 1 1 1 +# l=2/r=0(tl) l=2/r=0/s=2(tl) + +$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:0,3,2 #-mask=b:3,0,2 + +# x=4/y=1 +# 1 1 1 1 * * * * 1 1 1 1 * * * * +# * 1 1 1 1 * * * 1 1 1 1 1 * * * +# * * 1 1 1 1 * * ----> 1 1 1 1 1 1 * * +# * * * 1 1 1 1 * 1 1 * 1 1 1 1 * +# * * * * 1 1 1 1 1 1 * * 1 1 1 1 +# l=0/r=3(tl) l=0/r=3/s=2(tl) +# l=3/r=0(br) l=3/r=0/s=2(br) + + +$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:1,0,2 + +# x=4/y=-1 +# * * 1 1 * * * * 1 1 1 1 * * * * +# * * * 1 1 * * * 1 1 * 1 1 * * * +# * * * * 1 1 * * ----> 1 1 * * 1 1 * * +# * * * * * 1 1 * 1 1 * * * 1 1 * +# * * * * * * 1 1 1 1 * * * * 1 1 +# l=1/r=0(br) l=1/r=0/s=2(br) + + +$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:2,0,2 + +# x=-1/y=5 + +# * * * * * * * * 
* * * * +# * * * * * * * * * * * * +# 1 * * * * * 1 * * * * * +# 1 1 * * * * 1 1 * * * * +# 1 1 1 * * * ----> 1 1 1 * * * +# * 1 1 1 * * 1 1 1 1 * * +# * * 1 1 1 * 1 1 1 1 1 * +# * * * 1 1 1 1 1 * 1 1 1 +# l=2/r=0(br) l=2/r=0/s=2(br) + + +$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:-1,1,2 +# x=-1/y=8 +# * * * * * * * * * * +# * * * * * * * * * * +# 1 * * * * ----> 1 * * * * +# 1 1 * * * 1 1 * * * +# 1 1 1 * * 1 1 1 * * +# 1 1 1 1 * 1 1 1 1 * +# 1 1 1 1 1 1 1 1 1 1 +# 1 1 1 1 1 1 1 1 1 1 +# l=-1/r=1(br) l=-1/r=1/s=2(br) + diff --git a/include/ck_tile/host/reference/reference_batched_masking.hpp b/include/ck_tile/host/reference/reference_batched_masking.hpp index c2dd8abe23..a172a0013e 100644 --- a/include/ck_tile/host/reference/reference_batched_masking.hpp +++ b/include/ck_tile/host/reference/reference_batched_masking.hpp @@ -20,7 +20,7 @@ CK_TILE_HOST void reference_batched_masking(HostTensor& c_b_m_n, cons { for(int m = 0; m < M; ++m) { - if(mask.IsOutOfBound(m, n)) + if(mask.IsOutOfSinkBound(m, n)) c_b_m_n(batch, m, n) = -ck_tile::numeric::infinity(); } } diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index 756968871d..4ffb303812 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -86,21 +86,22 @@ struct GenericAttentionMask static constexpr const char* name = impl::MaskName::name; CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_total_, index_t x_total_) - : GenericAttentionMask(0, 0, y_total_, x_total_) + : GenericAttentionMask(0, 0, 0, y_total_, x_total_) { } CK_TILE_HOST_DEVICE - GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_) - : y(y_), x(x_), y_total(y_total_), x_total(x_total_) + GenericAttentionMask(index_t y_, index_t x_, 
index_t sink_, index_t y_total_, index_t x_total_) + : y(y_), x(x_), sink(sink_), y_total(y_total_), x_total(x_total_) { } template CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord) : y(mask_coord.at(number<0>{})), x(mask_coord.at(number<1>{})), - y_total(mask_coord.at(number<2>{})), - x_total(mask_coord.at(number<3>{})) + sink(mask_coord.at(number<2>{})), + y_total(mask_coord.at(number<3>{})), + x_total(mask_coord.at(number<4>{})) { } @@ -141,6 +142,44 @@ struct GenericAttentionMask } } + template + CK_TILE_HOST_DEVICE constexpr auto + GetSinkTileRangeAlongX(index_t i_y, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, 0, x_total); + } + else + { + // get the tile start/end range assum we loop over along X tile by tile + index_t x_start = [&]() { + if constexpr(IsLocal) + { + index_t tmp = max(-y + i_y + 1, 0); + return (tmp / XTile) * XTile; // round to tile aligned + } + else + { + return 0; + } + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t x_end = [&]() { + index_t tmp = min(i_y + YTile - 1 + x, x_total); + return ((tmp + XTile - 1) / XTile) * XTile; + }(); + + index_t sink_seq_end = sink > 0 ? 
((sink + XTile - 1) / XTile) * XTile : 0; + if(x_start <= sink_seq_end && sink > 0) + return ck_tile::make_tuple(0, 0, x_end); + else + return ck_tile::make_tuple(sink_seq_end, x_start, x_end); + } + } + // to get the loop length along Y axis, return index:[start, end), end-start=length // use this if need loop over Y axis tile by tile (like q-seqlen loopover) // TODO: y_end still could be negative, so end-start could be negative(need check) @@ -195,6 +234,30 @@ struct GenericAttentionMask } } + CK_TILE_HOST_DEVICE constexpr auto IsOutOfSinkBound(index_t i_y, index_t i_x) const + { + if constexpr(!IsMasking) + return i_x >= x_total; + // no need to do min/max here, since i_x will never be < 0 or >= x_total + index_t x_start = -y + i_y + 1; + index_t x_end = min(i_y + x, x_total); + + if constexpr(IsLocal) + { + if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total) + return false; + else + return i_x < x_start || i_x >= x_end; + } + else + { + if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total) + return false; + else + return i_x >= x_end || i_y >= y_total; + } + } + // if current tile is at the edge, means need per-pixel mask check. // otherwise no need to check per-pixel // Attention! 
assume the idex passed in this function is with in range of GetTileRangeAlongX/Y() @@ -237,7 +300,7 @@ struct GenericAttentionMask } private: - index_t y, x; + index_t y, x, sink; index_t y_total, x_total; }; @@ -260,21 +323,23 @@ struct SimplifiedGenericAttentionMask static constexpr const char* name = impl::SimplifiedMaskName::name; CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_total_, index_t x_total_) - : SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_) + : SimplifiedGenericAttentionMask(0, 0, 0, y_total_, x_total_) { } CK_TILE_HOST_DEVICE - SimplifiedGenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_) - : y(y_), x(x_), y_total(y_total_), x_total(x_total_) + SimplifiedGenericAttentionMask( + index_t y_, index_t x_, index_t sink_, index_t y_total_, index_t x_total_) + : y(y_), x(x_), sink(sink_), y_total(y_total_), x_total(x_total_) { } template CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates& mask_coord) : y(mask_coord.at(number<0>{})), x(mask_coord.at(number<1>{})), - y_total(mask_coord.at(number<2>{})), - x_total(mask_coord.at(number<3>{})) + sink(mask_coord.at(number<2>{})), + y_total(mask_coord.at(number<3>{})), + x_total(mask_coord.at(number<4>{})) { } @@ -308,6 +373,38 @@ struct SimplifiedGenericAttentionMask } } + template + CK_TILE_HOST_DEVICE constexpr auto + GetSinkTileRangeAlongX(index_t i_y, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, 0, x_total); + } + else + { + // get the tile start/end range assum we loop over along X tile by tile + index_t x_start = [&]() { + index_t tmp = max(-y + i_y + 1, 0); + return (tmp / XTile) * XTile; // round to tile aligned + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... 
in which case end-start is negative + index_t x_end = [&]() { + index_t tmp = min(i_y + YTile - 1 + x, x_total); + return ((tmp + XTile - 1) / XTile) * XTile; + }(); + + index_t sink_seq_end = sink > 0 ? ((sink + XTile - 1) / XTile) * XTile : 0; + + if(x_start <= sink_seq_end && sink > 0) + return ck_tile::make_tuple(0, 0, x_end); + else + return ck_tile::make_tuple(sink_seq_end, x_start, x_end); + } + } + template CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y, number height, @@ -325,6 +422,29 @@ struct SimplifiedGenericAttentionMask ck_tile::min(origin_end, split_end)); } + template + CK_TILE_HOST_DEVICE constexpr auto GetSinkTileRangeAlongX(index_t i_y, + number height, + number width, + index_t num_splits, + index_t i_split) const + { + auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width); + const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits)); + const index_t split_start = x_per_split * i_split; // 128 + const index_t split_end = ck_tile::min(x_total, split_start + x_per_split); // 256 + const index_t sink_seq_end = sink > 0 ? ((sink + width - 1) / width) * width : 0; + const index_t start = ck_tile::max(origin_start, split_start); + const index_t end = ck_tile::min(origin_end, split_end); + const bool is_first_intersecting_split = + (split_start <= origin_start && split_end >= origin_start); + const bool sink_in_range = (sink_seq_end <= start); + + const index_t sink_offset = + (is_first_intersecting_split && sink_in_range) ? 
sink_seq_end : 0; + return ck_tile::make_tuple(sink_offset, start, end); + } + // to get the loop length along Y axis, return index:[start, end), end-start=length // use this if need loop over Y axis tile by tile (like q-seqlen loopover) // TODO: y_end still could be negative, so end-start could be negative(need check) @@ -368,11 +488,22 @@ struct SimplifiedGenericAttentionMask { index_t x_start = -y + i_y + 1; // this could be negative, but it's fine index_t x_end = min(i_y + x, x_total); // need min in case x is padded - return i_x < x_start || i_x >= x_end || i_y >= y_total; } } + CK_TILE_HOST_DEVICE constexpr auto IsOutOfSinkBound(index_t i_y, index_t i_x) const + { + if constexpr(!IsMasking) + return i_x >= x_total; + index_t x_start = -y + i_y + 1; // this could be negative, but it's fine + index_t x_end = min(i_y + x, x_total); // need min in case x is padded + if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total) + return false; + else + return i_x < x_start || i_x >= x_end || i_y >= y_total; + } + // if current tile is at the edge, means need per-pixel mask check. // otherwise no need to check per-pixel // Attention! 
assume the idex passed in this function is with in range of GetTileRangeAlongX/Y() @@ -406,7 +537,7 @@ struct SimplifiedGenericAttentionMask } private: - index_t y, x; + index_t y, x, sink; index_t y_total, x_total; }; @@ -620,6 +751,7 @@ static constexpr bool is_generic_attention_mask_v = is_generic_attention_mask +CK_TILE_HOST_DEVICE constexpr auto +make_generic_attention_mask_from_lr_window(index_t left_size, + index_t right_size, + index_t sink_size, + index_t y_total, + index_t x_total, + bool is_top_left = true) +{ + auto r = make_generic_attention_mask_coordinates_from_lr_window( + left_size, right_size, sink_size, y_total, x_total, is_top_left); + return MaskType{r.at(number<0>{}), r.at(number<1>{}), sink_size, y_total, x_total}; } template @@ -649,7 +795,7 @@ make_generic_attention_mask_from_lr_window(index_t left_size, bool is_top_left = true) { auto r = make_generic_attention_mask_coordinates_from_lr_window( - left_size, right_size, y_total, x_total, is_top_left); - return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total}; + left_size, right_size, 0, y_total, x_total, is_top_left); + return MaskType{r.at(number<0>{}), r.at(number<1>{}), 0, y_total, x_total}; } } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/block/variants.hpp b/include/ck_tile/ops/fmha/block/variants.hpp index 29d9cf2a8e..b6f79873b4 100644 --- a/include/ck_tile/ops/fmha/block/variants.hpp +++ b/include/ck_tile/ops/fmha/block/variants.hpp @@ -162,6 +162,17 @@ struct StandardAttention { return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx); } + + template + __device__ __forceinline__ bool LogitsSinkMask(const Params& params, + [[maybe_unused]] uint32_t batch_idx, + uint32_t qo_idx, + uint32_t kv_idx, + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx); + } }; template @@ -224,6 +235,17 @@ struct LogitsSoftCap { return !params.impl_mask.IsOutOfBound(qo_idx, 
kv_idx); } + + template + __device__ __forceinline__ bool LogitsSinkMask(const Params& params, + [[maybe_unused]] uint32_t batch_idx, + uint32_t qo_idx, + uint32_t kv_idx, + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx); + } }; constexpr uint32_t CUSTOM_MASK = 1U; @@ -297,6 +319,17 @@ struct ComposedAttention { return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx); } + + template + __device__ __forceinline__ bool LogitsSinkMask(const Params& params, + [[maybe_unused]] uint32_t batch_idx, + uint32_t qo_idx, + uint32_t kv_idx, + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index e63ad8252b..10b5d0120e 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -200,7 +200,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel struct FmhaFwdMaskKargs { // ck_tile::index_t window_size_left, window_size_right; - ck_tile::index_t window_size_left, window_size_right; + ck_tile::index_t window_size_left, window_size_right, sink_size; ck_tile::GenericAttentionMaskEnum mask_type; }; @@ -356,6 +356,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, + ck_tile::index_t sink_size, ck_tile::index_t mask_type, float p_drop, bool s_randval, @@ -418,6 +419,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel { kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; + kargs.sink_size = sink_size; kargs.mask_type = static_cast(mask_type); } if constexpr(kStoreLSE) @@ -497,6 +499,7 @@ struct 
FmhaBatchPrefillWithPagedKVCacheKernel ck_tile::index_t batch_stride_v, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, + ck_tile::index_t sink_size, ck_tile::index_t mask_type, float p_drop, bool s_randval, @@ -557,6 +560,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel { kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; + kargs.sink_size = sink_size; kargs.mask_type = static_cast(mask_type); } if constexpr(kStoreLSE) @@ -1008,6 +1012,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel return ck_tile::make_generic_attention_mask_from_lr_window( kargs.window_size_left, kargs.window_size_right, + kargs.sink_size, kargs.seqlen_q, kargs.seqlen_k, kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 9890d1f2e4..9160e79af6 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -58,6 +58,7 @@ struct FmhaFwdKernel static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum; static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; + static constexpr bool kHasSink = FmhaPipeline::kHasSink; using AttentionVariant = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; @@ -155,7 +156,7 @@ struct FmhaFwdKernel struct FmhaFwdMaskKargs { // ck_tile::index_t window_size_left, window_size_right; - ck_tile::index_t window_size_left, window_size_right; + ck_tile::index_t window_size_left, window_size_right, sink_size; ck_tile::GenericAttentionMaskEnum mask_type; }; @@ -335,6 +336,7 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, + ck_tile::index_t sink_size, ck_tile::index_t mask_type, float p_drop, bool s_randval, @@ -393,6 +395,7 
@@ struct FmhaFwdKernel { kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; + kargs.sink_size = sink_size; kargs.mask_type = static_cast(mask_type); } if constexpr(kStoreLSE) @@ -481,6 +484,7 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, + ck_tile::index_t sink_size, ck_tile::index_t mask_type, float p_drop, bool s_randval, @@ -529,6 +533,7 @@ struct FmhaFwdKernel batch_stride_o, window_size_left, window_size_right, + sink_size, mask_type, p_drop, s_randval, @@ -580,6 +585,7 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, + ck_tile::index_t sink_size, ck_tile::index_t mask_type, float p_drop, bool s_randval, @@ -628,6 +634,7 @@ struct FmhaFwdKernel batch_stride_o, window_size_left, window_size_right, + sink_size, mask_type, p_drop, s_randval, @@ -673,6 +680,7 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, + ck_tile::index_t sink_size, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, @@ -732,6 +740,7 @@ struct FmhaFwdKernel { kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; + kargs.sink_size = sink_size; kargs.mask_type = static_cast(mask_type); } if constexpr(kStoreLSE) @@ -817,6 +826,7 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, + ck_tile::index_t sink_size, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, @@ -861,6 +871,7 @@ struct FmhaFwdKernel nhead_stride_o, window_size_left, window_size_right, + sink_size, mask_type, min_seqlen_q, p_drop, @@ -908,6 +919,7 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, + ck_tile::index_t sink_size, 
ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, @@ -952,6 +964,7 @@ struct FmhaFwdKernel nhead_stride_o, window_size_left, window_size_right, + sink_size, mask_type, min_seqlen_q, p_drop, @@ -1443,6 +1456,7 @@ struct FmhaFwdKernel return ck_tile::make_generic_attention_mask_from_lr_window( kargs.window_size_left, kargs.window_size_right, + kargs.sink_size, kargs.seqlen_q, kargs.seqlen_k, kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); @@ -2182,6 +2196,7 @@ struct FmhaFwdKernel return ck_tile::make_generic_attention_mask_from_lr_window( kargs.window_size_left, kargs.window_size_right, + kargs.sink_size, kargs.seqlen_q, kargs.seqlen_k, kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp index 677ead91ad..b75b35fc1e 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp @@ -55,6 +55,7 @@ struct FmhaFwdPagedKVKernel static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; + static constexpr bool kHasSink = FmhaPipeline::kHasSink; using AttentionVariant = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; @@ -101,7 +102,7 @@ struct FmhaFwdPagedKVKernel (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? 
"_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ); + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" ); #undef _SS_ #undef _TS_ // clang-format on @@ -189,7 +190,7 @@ struct FmhaFwdPagedKVKernel struct FmhaFwdMaskKargs { // ck_tile::index_t window_size_left, window_size_right; - ck_tile::index_t window_size_left, window_size_right; + ck_tile::index_t window_size_left, window_size_right, sink_size; ck_tile::GenericAttentionMaskEnum mask_type; }; @@ -326,6 +327,7 @@ struct FmhaFwdPagedKVKernel ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, + ck_tile::index_t sink_size, ck_tile::index_t mask_type) { Kargs kargs{{q_ptr, @@ -379,6 +381,7 @@ struct FmhaFwdPagedKVKernel { kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; + kargs.sink_size = sink_size; kargs.mask_type = static_cast(mask_type); } if constexpr(kStoreLSE) @@ -453,6 +456,7 @@ struct FmhaFwdPagedKVKernel ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, + ck_tile::index_t sink_size, ck_tile::index_t mask_type) { return MakeKargsImpl(q_ptr, @@ -495,6 +499,7 @@ struct FmhaFwdPagedKVKernel batch_stride_o, window_size_left, window_size_right, + sink_size, mask_type); } @@ -536,6 +541,7 @@ struct FmhaFwdPagedKVKernel ck_tile::index_t batch_stride_v, // only used for paged-kvcache ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, + ck_tile::index_t sink_size, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q) { @@ -590,6 +596,7 @@ struct FmhaFwdPagedKVKernel { kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; + kargs.sink_size = sink_size; 
kargs.mask_type = static_cast(mask_type); } if constexpr(kStoreLSE) @@ -660,6 +667,7 @@ struct FmhaFwdPagedKVKernel ck_tile::index_t batch_stride_v, // only used for paged-kvcache ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, + ck_tile::index_t sink_size, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q) { @@ -699,6 +707,7 @@ struct FmhaFwdPagedKVKernel batch_stride_v, window_size_left, window_size_right, + sink_size, mask_type, min_seqlen_q); } @@ -1257,6 +1266,7 @@ struct FmhaFwdPagedKVKernel return ck_tile::make_generic_attention_mask_from_lr_window( kargs.window_size_left, kargs.window_size_right, + kargs.sink_size, kargs.seqlen_q, kargs.seqlen_k, kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 19592e8bf4..bd5cddb526 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -51,6 +51,7 @@ struct FmhaFwdSplitKVKernel static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; + static constexpr bool kHasSink = FmhaPipeline::Problem::kHasSink; static constexpr bool kMergeNumHeadGroupsSeqLenQ = FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ; using AttentionVariant = ck_tile::remove_cvref_t; @@ -101,7 +102,7 @@ struct FmhaFwdSplitKVKernel "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + - (kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? 
"_pagedkv" : "_npagedkv" ); + (kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" ); #undef _SS_ #undef _TS_ // clang-format on @@ -198,7 +199,7 @@ struct FmhaFwdSplitKVKernel struct MaskKargs { // ck_tile::index_t window_size_left, window_size_right; - ck_tile::index_t window_size_left, window_size_right; + ck_tile::index_t window_size_left, window_size_right, sink_size; ck_tile::GenericAttentionMaskEnum mask_type; }; @@ -325,6 +326,7 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t split_stride_o_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, + ck_tile::index_t sink_size, ck_tile::index_t mask_type) { Kargs kargs{{q_ptr, @@ -384,6 +386,7 @@ struct FmhaFwdSplitKVKernel { kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; + kargs.sink_size = sink_size; kargs.mask_type = static_cast(mask_type); } if constexpr(kDoFp8StaticQuant) @@ -451,6 +454,7 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t split_stride_o_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, + ck_tile::index_t sink_size, ck_tile::index_t mask_type) { Kargs kargs{{q_ptr, @@ -508,6 +512,7 @@ struct FmhaFwdSplitKVKernel { kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; + kargs.sink_size = sink_size; kargs.mask_type = static_cast(mask_type); } if constexpr(kDoFp8StaticQuant) @@ -994,6 +999,7 @@ struct FmhaFwdSplitKVKernel return ck_tile::make_generic_attention_mask_from_lr_window( kargs.window_size_left, kargs.window_size_right, + kargs.sink_size, kargs.seqlen_q, kargs.seqlen_k, kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp index 693f81d08a..d55d0d9342 100644 --- 
a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp @@ -57,6 +57,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kIsPagedKV = Problem::kIsPagedKV; + static constexpr bool kHasSink = Problem::kHasSink; static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || @@ -228,10 +229,22 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS clear_tile(o_acc); set_tile(m, -numeric::infinity()); clear_tile(l); - - const auto q_origin = q_dram_window.get_window_origin(); - const auto [logical_seqlen_k_start, logical_seqlen_k_end] = - mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + const auto q_origin = q_dram_window.get_window_origin(); + const auto tile_range_result = [&mask, &q_origin]() { + if constexpr(kHasSink) + return mask.GetSinkTileRangeAlongX( + q_origin.at(number<0>{}), number{}, number{}); + else + { + auto [start, end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + return ck_tile::make_tuple(0, start, end); + } + }(); + const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{}); + const auto logical_seqlen_k_start = tile_range_result.get(ck_tile::number<1>{}); + const auto logical_seqlen_k_end = tile_range_result.get(ck_tile::number<2>{}); + const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0); // check early exit if no work to do if constexpr(FmhaMask::IsMasking || kPadSeqLenK) @@ -255,7 +268,6 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS return o_acc; } } - // k_dram_block_window const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset; const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset; @@ -274,27 +286,36 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS 
return physical_seqlen_k_start_; } }(); + const auto kv_load_start = (sink_seq_end == 0 && aligned_physical_seqlen_k_start > 0) + ? aligned_physical_seqlen_k_start + : 0; const index_t num_total_loop = - integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0); + integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0) + + num_sink_loop; auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window( - k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0}); + k_dram_block_window_lengths, {kv_load_start, 0}); + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + const index_t bias_n_offset = [&]() { + if constexpr(kHasSink) + return kv_load_start; + else + return logical_seqlen_k_start - + (physical_seqlen_k_start - aligned_physical_seqlen_k_start); + }(); - const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), - logical_seqlen_k_start - (physical_seqlen_k_start - - aligned_physical_seqlen_k_start)}, // M/N + {bias_origin.at(number<0>{}), bias_n_offset}, Policy::template MakeBiasDramTileDistribution()); // v_dram_window auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window( v_dram_block_window_lengths, - {0, aligned_physical_seqlen_k_start}, // TODO: hdim split? + {0, kv_load_start}, // TODO: hdim split? 
Policy::template MakeVDramTileDistribution()); - auto q_tile = tile_elementwise_in(q_element_func, q); // prefetch K tile @@ -321,9 +342,16 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); k_block_tile = load_tile(k_dram_window); } + const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops); + const auto k_move_offset = [&]() { + if constexpr(kHasSink) + return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0; + else + return kN0; + }(); auto physical_next_block_id_k = amd_wave_read_first_lane(k_page_block_navigator.prefetch_table_id( - i_page_block_k, k_dram_block_window, {kN0, 0})); + i_page_block_k, k_dram_block_window, {k_move_offset, 0})); auto physical_next_block_id_v = amd_wave_read_first_lane( v_page_block_navigator.prefetch_table_id(i_page_block_v, v_dram_window, {0, kK1})); @@ -442,7 +470,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS #endif } } - move_tile_window(bias_dram_window, {0, kN0}); + move_tile_window(bias_dram_window, {0, k_move_offset}); { const auto k_origin = k_page_block_navigator.to_global_window_origin( @@ -474,14 +502,29 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS number{}); if(need_perpixel_check) { - set_tile_if( - s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = - q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = - k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask.IsOutOfBound(row, col - kv_l2p_offset); + auto apply_mask = [&](auto&& mask_func) { + set_tile_if(s_acc, + -numeric::infinity(), + [&](auto tile_idx) { + const auto row = + q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask_func(row, col - kv_l2p_offset); + }); + }; + + if constexpr(kHasSink) + { + apply_mask([&](auto row, auto col) { + return mask.IsOutOfSinkBound(row, col); }); + } + else + { + apply_mask( + [&](auto row, auto col) { 
return mask.IsOutOfBound(row, col); }); + } } } } @@ -647,7 +690,12 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS } // move K tile windows i_page_block_k = k_page_block_navigator.move_tile_window( - i_page_block_k, k_dram_block_window, {kN0, 0}, physical_next_block_id_k); + i_page_block_k, k_dram_block_window, {k_move_offset, 0}, physical_next_block_id_k); + physical_next_block_id_v = + amd_wave_read_first_lane(v_page_block_navigator.prefetch_table_id( + i_page_block_v, v_dram_window, {0, k_move_offset - kN0})); + i_page_block_v = v_page_block_navigator.move_tile_window( + i_page_block_v, v_dram_window, {0, k_move_offset - kN0}, physical_next_block_id_v); // tail { block_sync_lds(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp index 0b30077a29..944d49a8aa 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp @@ -57,6 +57,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kIsPagedKV = Problem::kIsPagedKV; static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; + static constexpr bool kHasSink = Problem::kHasSink; static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || @@ -256,11 +257,23 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS set_tile(m, -numeric::infinity()); clear_tile(l); - const auto q_origin = q_dram_window.get_window_origin(); - const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX( - q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); + const auto q_origin = q_dram_window.get_window_origin(); + const auto tile_range_result = [&mask, 
&q_origin, num_splits, i_split]() { + if constexpr(kHasSink) + return mask.GetSinkTileRangeAlongX( + q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); + else + { + auto [start, end] = mask.GetTileRangeAlongX( + q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); + return ck_tile::make_tuple(0, start, end); + } + }(); + const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{}); + const auto logical_seqlen_k_start = tile_range_result.get(ck_tile::number<1>{}); + const auto logical_seqlen_k_end = tile_range_result.get(ck_tile::number<2>{}); - // check early exit if no work to do + const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0); if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits) { const index_t logical_num_total_loop = @@ -304,24 +317,33 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS return physical_seqlen_k_start_; } }(); + const auto kv_load_start = (sink_seq_end == 0 && aligned_physical_seqlen_k_start > 0) + ? 
aligned_physical_seqlen_k_start + : 0; const index_t num_total_loop = - integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0); + integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0) + + num_sink_loop; auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window( - k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0}); + k_dram_block_window_lengths, {kv_load_start, 0}); - const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + const index_t bias_n_offset = [&]() { + if constexpr(kHasSink) + return kv_load_start; + else + return logical_seqlen_k_start - + (physical_seqlen_k_start - aligned_physical_seqlen_k_start); + }(); auto bias_dram_window = make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), - logical_seqlen_k_start - (physical_seqlen_k_start - - aligned_physical_seqlen_k_start)}, // M/N + {bias_origin.at(number<0>{}), bias_n_offset}, Policy::template MakeBiasDramTileDistribution()); auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window( v_dram_block_window_lengths, - {0, aligned_physical_seqlen_k_start}, // TODO: hdim split? + {0, kv_load_start}, // TODO: hdim split? Policy::template MakeVDramTileDistribution()); // store Q into LDS @@ -369,7 +391,13 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS { // STAGE 1, QK gemm clear_tile(s_acc); // initialize C - + const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops); + const auto k_move_offset = [&]() { + if constexpr(kHasSink) + return is_sink_tile ? 
logical_seqlen_k_start - sink_seq_end + kN0 : kN0; + else + return kN0; + }(); // load the second tile of the first iteration k_block_tile = load_tile(k_dram_window); @@ -494,7 +522,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS #endif } } - move_tile_window(bias_dram_window, {0, kN0}); + move_tile_window(bias_dram_window, {0, k_move_offset}); /// TODO: only check in first/last iteration without increasing code size if constexpr(kHasUnevenSplits) @@ -505,7 +533,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS s_acc, -numeric::infinity(), [&, - physical_seqlen_k_start_ = physical_seqlen_k_start, + physical_seqlen_k_start_ = is_sink_tile ? 0 : physical_seqlen_k_start, physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) { const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); if constexpr(kIsPagedKV) @@ -530,12 +558,26 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS number{}); if(need_perpixel_check) { - set_tile_if( - s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask.IsOutOfBound(row, col - kv_l2p_offset); - }); + auto apply_mask = [&](auto&& mask_func) { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = + q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask_func(row, col - kv_l2p_offset); + }); + }; + + if constexpr(kHasSink) + { + apply_mask( + [&](auto row, auto col) { return mask.IsOutOfSinkBound(row, col); }); + } + else + { + apply_mask([&](auto row, auto col) { return mask.IsOutOfBound(row, col); }); + } } } @@ -546,7 +588,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS { // move K tile windows i_page_block_k = k_page_block_navigator.move_tile_window( - i_page_block_k, k_dram_block_window, {kN0, 0}); + 
i_page_block_k, k_dram_block_window, {k_move_offset, 0}); k_dram_window = make_tile_window( k_dram_block_window, @@ -742,6 +784,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS // moving k_dram_window is an in-page-block operation, so there is // no need to invoke k_page_block_navigator.move_tile_window() here. move_tile_window(k_dram_window, {0, kK0}); + i_page_block_v = v_page_block_navigator.move_tile_window( + i_page_block_v, v_dram_window, {0, k_move_offset - kN0}); store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); } } while(++i_total_loops < num_total_loop); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index 6be6a64b1c..26a4cc905c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -56,6 +56,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kIsPagedKV = Problem::kIsPagedKV; static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; + static constexpr bool kHasSink = Problem::kHasSink; static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || @@ -229,9 +230,23 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS set_tile(m, -numeric::infinity()); clear_tile(l); - const auto q_origin = q_dram_window.get_window_origin(); - const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX( - q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); + const auto q_origin = q_dram_window.get_window_origin(); + const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() { + if constexpr(kHasSink) + return mask.GetSinkTileRangeAlongX( + q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); + 
else + { + auto [start, end] = mask.GetTileRangeAlongX( + q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); + return ck_tile::make_tuple(0, start, end); + } + }(); + const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{}); + const auto logical_seqlen_k_start = tile_range_result.get(ck_tile::number<1>{}); + const auto logical_seqlen_k_end = tile_range_result.get(ck_tile::number<2>{}); + + const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0); // check early exit if no work to do if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits) @@ -274,24 +289,35 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS return physical_seqlen_k_start_; } }(); + const auto kv_load_start = (sink_seq_end == 0 && aligned_physical_seqlen_k_start > 0) + ? aligned_physical_seqlen_k_start + : 0; const index_t num_total_loop = - integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0); + integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0) + + num_sink_loop; auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window( - k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0}); + k_dram_block_window_lengths, {kv_load_start, 0}); const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + + const index_t bias_n_offset = [&]() { + if constexpr(kHasSink) + return kv_load_start; + else + return logical_seqlen_k_start - + (physical_seqlen_k_start - aligned_physical_seqlen_k_start); + }(); + auto bias_dram_window = make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), - logical_seqlen_k_start - (physical_seqlen_k_start - - aligned_physical_seqlen_k_start)}, // M/N + {bias_origin.at(number<0>{}), bias_n_offset}, Policy::template MakeBiasDramTileDistribution()); auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window( 
v_dram_block_window_lengths, - {0, aligned_physical_seqlen_k_start}, // TODO: hdim split? + {0, kv_load_start}, // TODO: hdim split? Policy::template MakeVDramTileDistribution()); auto q_tile = tile_elementwise_in(q_element_func, q); @@ -320,9 +346,18 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); k_block_tile = load_tile(k_dram_window); } + const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops); + + const auto k_move_offset = [&]() { + if constexpr(kHasSink) + return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0; + else + return kN0; + }(); + auto physical_next_block_id_k = amd_wave_read_first_lane(k_page_block_navigator.prefetch_table_id( - i_page_block_k, k_dram_block_window, {kN0, 0})); + i_page_block_k, k_dram_block_window, {k_move_offset, 0})); auto physical_next_block_id_v = amd_wave_read_first_lane( v_page_block_navigator.prefetch_table_id(i_page_block_v, v_dram_window, {0, kK1})); @@ -441,7 +476,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS #endif } } - move_tile_window(bias_dram_window, {0, kN0}); + move_tile_window(bias_dram_window, {0, k_move_offset}); /// TODO: only check in first/last iteration without increasing code size if constexpr(kHasUnevenSplits) @@ -452,7 +487,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS s_acc, -numeric::infinity(), [&, - physical_seqlen_k_start_ = physical_seqlen_k_start, + physical_seqlen_k_start_ = is_sink_tile ? 
0 : physical_seqlen_k_start, physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) { const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); if constexpr(kIsPagedKV) @@ -477,12 +512,26 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS number{}); if(need_perpixel_check) { - set_tile_if( - s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask.IsOutOfBound(row, col - kv_l2p_offset); - }); + auto apply_mask = [&](auto&& mask_func) { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = + q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask_func(row, col - kv_l2p_offset); + }); + }; + + if constexpr(kHasSink) + { + apply_mask( + [&](auto row, auto col) { return mask.IsOutOfSinkBound(row, col); }); + } + else + { + apply_mask([&](auto row, auto col) { return mask.IsOutOfBound(row, col); }); + } } } @@ -647,7 +696,12 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS } // move K tile windows i_page_block_k = k_page_block_navigator.move_tile_window( - i_page_block_k, k_dram_block_window, {kN0, 0}, physical_next_block_id_k); + i_page_block_k, k_dram_block_window, {k_move_offset, 0}, physical_next_block_id_k); + physical_next_block_id_v = + amd_wave_read_first_lane(v_page_block_navigator.prefetch_table_id( + i_page_block_v, v_dram_window, {0, k_move_offset - kN0})); + i_page_block_v = v_page_block_navigator.move_tile_window( + i_page_block_v, v_dram_window, {0, k_move_offset - kN0}, physical_next_block_id_v); // tail { block_sync_lds(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 7c4a921b70..a192e3f7b0 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ 
b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -62,6 +62,7 @@ struct BlockFmhaPipelineProblem static constexpr bool kHasDropout = Traits::kHasDropout; static constexpr auto QScaleEnum = Traits::QScaleEnum; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; + static constexpr bool kHasSink = Traits::kHasSink; }; template {}), number{}, number{}); - const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + const auto tile_range_result = [&mask, &q_origin]() { + if constexpr(kHasSink) + return mask.GetSinkTileRangeAlongX( + q_origin.at(number<0>{}), number{}, number{}); + else + { + auto [start, end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + return ck_tile::make_tuple(0, start, end); + } + }(); + const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{}); + const auto seqlen_k_start = tile_range_result.get(ck_tile::number<1>{}); + const auto seqlen_k_end = tile_range_result.get(ck_tile::number<2>{}); + + const auto kv_load_start = (sink_seq_end == 0 && seqlen_k_start > 0) ? 
seqlen_k_start : 0; + const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0); + const auto num_total_loop = + integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0) + num_sink_loop; // check early exit if no work to do if constexpr(FmhaMask::IsMasking || kPadSeqLenK) @@ -262,22 +279,22 @@ struct BlockFmhaPipelineQRKSVS auto k_dram_block_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), k_dram_block_window_tmp.get_window_lengths(), - {seqlen_k_start, 0}); + {kv_load_start, 0}); const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + {bias_origin.at(number<0>{}), kv_load_start}, // M/N Policy::template MakeBiasDramTileDistribution()); auto randval_dram_window = dropout.template MakeRandvalDramWindow( - randval_dram_block_window_tmp, seqlen_k_start); + randval_dram_block_window_tmp, kv_load_start); auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), - {0, seqlen_k_start}, // TODO: hdim split? + {0, kv_load_start}, // TODO: hdim split? 
Policy::template MakeVDramTileDistribution()); auto q_tile = tile_elementwise_in(q_element_func, q); @@ -450,6 +467,11 @@ struct BlockFmhaPipelineQRKSVS #endif } } + if constexpr(kHasSink) + { + if(i_total_loops == 0) + move_tile_window(bias_dram_window, {0, seqlen_k_start - sink_seq_end}); + } move_tile_window(bias_dram_window, {0, kN0}); if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { @@ -460,17 +482,34 @@ struct BlockFmhaPipelineQRKSVS number{}); if(need_perpixel_check) { - set_tile_if( - s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return !variant.LogitsMask(variant_params, - block_indices.batch_idx, - row, - col, - block_indices.qo_head_idx, - block_indices.kv_head_idx); + auto apply_mask = [&](auto&& mask_func) { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = + q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return !mask_func(variant_params, + block_indices.batch_idx, + row, + col, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }); + }; + + if constexpr(kHasSink) + { + apply_mask([&](auto&&... args) { + return variant.LogitsSinkMask(std::forward(args)...); }); + } + else + { + apply_mask([&](auto&&... args) { + return variant.LogitsMask(std::forward(args)...); + }); + } } } @@ -580,11 +619,23 @@ struct BlockFmhaPipelineQRKSVS if constexpr(kHasDropout) { - // K and dropout use the same address in LDS, finish loading from k_lds_window by - // gemm_0 to reuse LDS. 
block_sync_lds(); + auto randval_ptr = reinterpret_cast(smem_ptr); + + index_t seq_offset = [&]() { + if constexpr(!kHasSink) + return seqlen_k_start + i_total_loops * kN0; + + const bool in_sink_phase = (num_sink_loop > i_total_loops); + if(i_total_loops == num_sink_loop) + move_tile_window(randval_dram_window, {0, seqlen_k_start - sink_seq_end}); + + return in_sink_phase ? (kv_load_start + i_total_loops * kN0) + : (seqlen_k_start + (i_total_loops - num_sink_loop) * kN0); + }(); + dropout.template Run( - smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window); + randval_ptr, seq_offset, p_compute, randval_dram_window); } block_sync_lds(); @@ -636,6 +687,14 @@ struct BlockFmhaPipelineQRKSVS }); } // move K tile windows + if constexpr(kHasSink) + { + if(i_total_loops == 0) + { + move_tile_window(k_dram_block_window, {seqlen_k_start - sink_seq_end, 0}); + move_tile_window(v_dram_window, {0, seqlen_k_start - sink_seq_end}); + } + } move_tile_window(k_dram_block_window, {kN0, 0}); // tail { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index e07516cc27..f57b89cf9d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -62,6 +62,7 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr bool kHasSink = Problem::kHasSink; static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || @@ -277,11 +278,26 @@ struct BlockFmhaPipelineQRKSVSAsync clear_tile(l); __builtin_amdgcn_sched_barrier(0); - const auto q_origin = q_dram_window.get_window_origin(); - const auto [seqlen_k_start, seqlen_k_end] = - 
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + const auto q_origin = q_dram_window.get_window_origin(); + const auto tile_range_result = [&mask, &q_origin]() { + if constexpr(kHasSink) + return mask.GetSinkTileRangeAlongX( + q_origin.at(number<0>{}), number{}, number{}); + else + { + auto [start, end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + return ck_tile::make_tuple(0, start, end); + } + }(); + const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{}); + const auto seqlen_k_start = tile_range_result.get(ck_tile::number<1>{}); + const auto seqlen_k_end = tile_range_result.get(ck_tile::number<2>{}); - const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + const auto kv_load_start = (sink_seq_end == 0 && seqlen_k_start > 0) ? seqlen_k_start : 0; + const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0); + const auto num_total_loop = + integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0) + num_sink_loop; // check early exit if no work to do if constexpr(FmhaMask::IsMasking || kPadSeqLenK) @@ -309,7 +325,7 @@ struct BlockFmhaPipelineQRKSVSAsync auto k_dram_block_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), k_dram_block_window_tmp.get_window_lengths(), - {seqlen_k_start, 0}); + {kv_load_start, 0}); auto k_dram_window = make_tile_window( k_dram_block_window.get_bottom_tensor_view(), @@ -332,16 +348,16 @@ struct BlockFmhaPipelineQRKSVSAsync auto bias_dram_window = make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + {bias_origin.at(number<0>{}), kv_load_start}, // M/N Policy::template MakeBiasDramTileDistribution()); auto randval_dram_window = dropout.template MakeRandvalDramWindow( - randval_dram_block_window_tmp, seqlen_k_start); + randval_dram_block_window_tmp, kv_load_start); auto v_dram_window = 
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), - {0, seqlen_k_start}, // TODO: hdim split? + {0, kv_load_start}, // TODO: hdim split? Policy::template MakeVDramTileDistribution()); // prefetch K tile @@ -478,6 +494,11 @@ struct BlockFmhaPipelineQRKSVSAsync #endif } } + if constexpr(kHasSink) + { + if(i_total_loops == 0) + move_tile_window(bias_dram_window, {0, seqlen_k_start - sink_seq_end}); + } move_tile_window(bias_dram_window, {0, kN0}); if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { @@ -489,17 +510,34 @@ struct BlockFmhaPipelineQRKSVSAsync if(need_perpixel_check) { - set_tile_if( - s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return !variant.LogitsMask(variant_params, - block_indices.batch_idx, - row, - col, - block_indices.qo_head_idx, - block_indices.kv_head_idx); + auto apply_mask = [&](auto&& mask_func) { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = + q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return !mask_func(variant_params, + block_indices.batch_idx, + row, + col, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }); + }; + + if constexpr(kHasSink) + { + apply_mask([&](auto&&... args) { + return variant.LogitsSinkMask(std::forward(args)...); }); + } + else + { + apply_mask([&](auto&&... 
args) { + return variant.LogitsMask(std::forward(args)...); + }); + } } } @@ -647,11 +685,21 @@ struct BlockFmhaPipelineQRKSVSAsync { auto randval_ptr = reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); + + index_t seq_offset = [&]() { + if constexpr(!kHasSink) + return seqlen_k_start + i_total_loops * kN0; + + const bool in_sink_phase = (num_sink_loop > i_total_loops); + if(i_total_loops == num_sink_loop) + move_tile_window(randval_dram_window, {0, seqlen_k_start - sink_seq_end}); + + return in_sink_phase ? (kv_load_start + i_total_loops * kN0) + : (seqlen_k_start + (i_total_loops - num_sink_loop) * kN0); + }(); + dropout.template Run( - randval_ptr, - seqlen_k_start + i_total_loops * kN0, - p_compute, - randval_dram_window); + randval_ptr, seq_offset, p_compute, randval_dram_window); } const auto p = [&]() { @@ -717,8 +765,16 @@ struct BlockFmhaPipelineQRKSVSAsync i_total_loops++; if(i_total_loops < num_total_loop) { - // move K tile windows + if constexpr(kHasSink) + { + if(i_total_loops == 0) + { + move_tile_window(k_dram_block_window, {seqlen_k_start - sink_seq_end, 0}); + move_tile_window(v_dram_window, {0, seqlen_k_start - sink_seq_end}); + } + } move_tile_window(k_dram_block_window, {kN0, 0}); + k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); if constexpr(k1_loops >= 2 && diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index 5d224a6adf..26662dafeb 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -69,6 +69,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasUnevenSplits = true; + static constexpr bool kHasSink = Problem::kHasSink; 
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index df33a93696..757a852c19 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -20,8 +20,9 @@ template + index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */ + bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */ + bool kHasSink_ = false> struct TileFmhaTraits { static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; @@ -36,6 +37,7 @@ struct TileFmhaTraits static constexpr auto QScaleEnum = QScaleEnum_; static constexpr index_t kBlockPerCu = kBlockPerCu_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; + static constexpr bool kHasSink = kHasSink_; }; template 1 or fwd training is running */ bool kIsPagedKV_, bool kDoFp8StaticQuant_, - index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */ - bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */> + index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */ + bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */ + bool kHasSink_ = false> struct TileFmhaFwdPagedKVTraits { static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; @@ -81,6 +84,7 @@ struct TileFmhaFwdPagedKVTraits static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr index_t kBlockPerCu = kBlockPerCu_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; + static constexpr bool kHasSink = kHasSink_; }; template + index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */ + bool kHasSink_ = false> struct TileFmhaFwdSplitKVTraits { static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; @@ -112,6 +117,7 @@ struct TileFmhaFwdSplitKVTraits static constexpr bool kHasUnevenSplits = kHasUnevenSplits_; static constexpr bool 
kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_; static constexpr index_t kBlockPerCu = kBlockPerCu_; + static constexpr bool kHasSink = kHasSink_; }; template