Add attention sink support for FMHA FWD (#3368)

* Revert "Revert "Add attn sink (#2892)" (#3250)"

This reverts commit e3be392d13e6ee107d823af32aca2d3ff03ca69d.

* fix conflict

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* Add F_sink parameter to FmhaFwdPipeline

* Update tile_fmha_traits.hpp

* Refactor pipeline creation in fmha_fwd.py

Updated the pipeline creation logic to include 'sink' parameter in product combinations and adjusted the FmhaFwdPipeline calls accordingly.

* Update fmha_fwd.py

* Update fmha_fwd.py

* Update example/ck_tile/01_fmha/script/correct_test_fwd_sink.sh

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* update CHANGELOG.md

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* Update CHANGELOG with new features and support

* Update fmha_fwd.hpp

* Update CHANGELOG.md

* Update smoke_test_fwd_sink.sh

* Update correct_test_fwd_sink.sh

* Update smoke_test_fwd_sink.sh

---------

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

[ROCm/composable_kernel commit: f5573f56d9]
This commit is contained in:
Linjun-AMD
2025-12-15 12:21:59 +08:00
committed by GitHub
parent eeb78c46a4
commit 51886bf22b
26 changed files with 948 additions and 196 deletions

View File

@@ -65,7 +65,7 @@ set(FMHA_BWD_CODE_GEN_COMMON_ARGS
# there is no corresponding instance for parameters).
if(BUILD_TESTING)
# Filters are in the order of FMHA_FWD_KNOWN_APIS: fwd,fwd_splitkv_combine@fwd_splitkv,fwd_appendkv,pagedkv_prefill
list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*,*@*_nlogits*_nbias*,*,*_nlogits*_nskip*_pagedkv)
list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*_nsink*,*@*_nlogits*_nbias*_nsink*,*,*_nlogits*_nskip*_pagedkv*)
endif()
# generate a list of kernels, but not actually emit files at config sta

View File

@@ -76,7 +76,8 @@ using fmha_traits = ck_tile::TileFmhaTraits<{F_spad},
{F_dropout},
{F_qscale},
{F_occupancy},
{F_skip}>;
{F_skip},
{F_sink}>;
using fmha_variant = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
@@ -113,7 +114,7 @@ using fmha_kernel = {F_kernel}<fmha_pipeline, fmha_epilogue>;
using trait = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
{F_pipeline_enum}, {F_logits}, fmha_mask, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink}>;
template<>
float fmha_fwd_<trait, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_args a)
@@ -229,9 +230,9 @@ FMHA_FWD_API_PER_HDIM_CASE = """{F_if}(t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hd
}}
"""
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) &&
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) &&(t.has_sink == {F_sink}) &&
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink}>;
return fmha_fwd_<trait_, {F_arch.tag}>(s, a);
}}
"""
@@ -278,13 +279,14 @@ class FmhaFwdApiTrait:
dvpad: str
skip: str
tr_load: str
sink: str
constraint: CppConstraint
@property
def name(self) -> str:
return (
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}-{self.sink}"
)
@property
@@ -384,6 +386,7 @@ class FmhaFwdPipeline:
F_mask: str # value from MASK_MAP
F_skip: str # true/false
F_trload: str # true/false
F_sink: str # true/false
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
@property
@@ -454,6 +457,10 @@ class FmhaFwdPipeline:
n += "_trload"
else:
n += "_ntrload"
if self.F_sink == "t":
n += "_sink"
else:
n += "_nsink"
return n
@@ -543,6 +550,7 @@ class FmhaFwdApiPool:
F_trload=BOOL_MAP[trait.tr_load],
F_qscale_check=QSCALE_CHECK_MAP[trait.qscale],
F_qscale=QSCALE_MAP[trait.qscale],
F_sink=BOOL_MAP[trait.sink],
F_scheck=trait.scheck,
F_seqtune=trait.seqtune(max_bm0),
F_skcheck=trait.skcheck,
@@ -683,6 +691,7 @@ class FmhaFwdKernel:
F_pipeline=PIPELINE_MAP[self.F_pipeline.tag],
F_kernel=self._get_cpp_kernel_class_name(self.F_pipeline.tag),
F_kargs_creator=self._get_cpp_kargs_creator_func_name(self.F_pipeline.tag),
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
)
@property
@@ -725,6 +734,7 @@ class FmhaFwdKernel:
dvpad=self.F_pipeline.F_dvpad,
skip=self.F_pipeline.F_skip,
tr_load=self.F_pipeline.F_trload,
sink=self.F_pipeline.F_sink,
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
)
@@ -957,52 +967,55 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
pipelines = []
if dtype in cls._DT_FP32:
qscale = "no"
for logits, mask, bias, lse, dropout, skip in itertools.product(
for logits, mask, bias, lse, dropout, skip, sink in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
BIAS_MAP.keys(),
["t", "f"],
["t", "f"],
["t", "f"],
["t", "f"],
):
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
elif dtype in cls._DT_FP16_BF16:
qscale = "no"
for logits, mask, bias, lse, dropout, skip in itertools.product(
for logits, mask, bias, lse, dropout, skip, sink in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
BIAS_MAP.keys(),
["t", "f"],
["t", "f"],
["t", "f"],
["t", "f"],
):
if hdim == 256 and hdim_v == 256:
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
# the below two is used for hdim vectorize load
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
else:
if bias == "bias":
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
else:
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
if receipt == 1 and bias != "bias":
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip
elif dtype in cls._DT_FP8BF16 or dtype in cls._DT_FP8FP32:
# no need lse/dropout kernels
for logits, qscale, mask, bias in itertools.product(
for logits, qscale, mask, bias, sink in itertools.product(
["f"],
["no", "pertensor"],
get_mask_map(mask_impl).keys(),
["no"],
["f", "t"],
):
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
elif dtype in ["fp8", "fp8fp16", "bf8"]:
# TODO
pass
@@ -1033,13 +1046,14 @@ class KernelComponentFactoryGfx950(
)
if dtype in cls._DT_FP16_BF16:
qscale = "no"
for logits, mask, bias, lse, dropout, skip in itertools.product(
for logits, mask, bias, lse, dropout, skip, sink in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
BIAS_MAP.keys(),
["t", "f"],
["t", "f"],
["t", "f"],
["t", "f"],
):
if (
(hdim, hdim_v) in [(64, 64), (128, 128)]
@@ -1048,15 +1062,15 @@ class KernelComponentFactoryGfx950(
and dropout == "f"
and skip == "f"
):
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip
# qr_async_trload_v3 only supports hdim=hdim_v=128 for now
if (hdim, hdim_v) == (128, 128):
# qr_async_trload_v3 only supports (generic) causal mask
for mask in ["no", "causal"]:
pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f",
F_logits="f", F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t")) # fmt: skip
F_logits="f", F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip
return pipelines
@@ -1105,23 +1119,24 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory):
pipelines = []
if dtype in cls._DT_FP16_BF16:
qscale = "no"
for logits, mask, bias, lse, dropout, skip in itertools.product(
for logits, mask, bias, lse, dropout, skip, sink in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
BIAS_MAP.keys(),
["t", "f"],
["t", "f"],
["t", "f"],
["t", "f"],
):
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32:
# no need lse/dropout kernels
for logits, qscale, mask, bias in itertools.product(
["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"]
):
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip
return pipelines

View File

@@ -73,7 +73,8 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
{F_pagedkv},
kHasUnevenSplits,
kMergeNumHeadGroupsSeqLenQ,
{F_occupancy}>;
{F_occupancy},
{F_sink}>;
using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
@@ -117,7 +118,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
}} // anonymous namespace
using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_sink}, {F_spad}, {F_skpad}, {F_dpad},
{F_dvpad}>;
#pragma clang diagnostic push
@@ -279,8 +280,8 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
"""
FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
// get combine kernel tile sizes
using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
@@ -332,6 +333,7 @@ class FmhaFwdSplitKVApiTrait:
dpad: str
dvpad: str
pagedkv: str
sink: str # sink or not
bn1comb: int # tile size along v head_dim of combine kernel
@property
@@ -339,7 +341,7 @@ class FmhaFwdSplitKVApiTrait:
return (
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-"
+ f"{self.dvpad}-{self.pagedkv}"
+ f"{self.dvpad}-{self.pagedkv}-{self.sink}"
)
@property
@@ -425,6 +427,7 @@ class FmhaFwdSplitKVPipeline:
F_lse: str #
F_squant: str #
F_pagedkv: str # t/f
F_sink: str # t/f
F_mask: str # value from MASK_MAP
@property
@@ -485,6 +488,10 @@ class FmhaFwdSplitKVPipeline:
n += "_pagedkv"
else:
n += "_npagedkv"
if self.F_sink == "t":
n += "_sink"
else:
n += "_nsink"
return n
@@ -567,6 +574,7 @@ class FmhaFwdSplitKVApiPool:
F_lse=BOOL_MAP[trait.lse],
F_squant=BOOL_MAP[trait.squant],
F_pagedkv=BOOL_MAP[trait.pagedkv],
F_sink=BOOL_MAP[trait.sink],
F_scheck=trait.scheck,
F_skcheck=trait.skcheck,
F_dcheck=trait.dcheck,
@@ -667,6 +675,7 @@ class FmhaFwdSplitKVKernel:
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
F_occupancy=self.F_tile.F_occupancy,
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode=MODE_MAP[self.F_mode],
@@ -740,19 +749,23 @@ class KernelComponentFactoryBase:
squant = "t" if dtype == "fp8" else "f"
pipelines = []
if dtype in ["fp16", "bf16"]:
for logits, mask, bias, pagedkv in itertools.product(
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]
for logits, mask, bias, pagedkv, sink in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
BIAS_MAP.keys(),
["t", "f"],
["t", "f"],
):
pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
elif dtype in ["fp8", "bf8"]:
for logits, mask, bias in itertools.product(
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
):
pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip
elif dtype in ["fp8fp16", "fp8bf16"]:
# TODO
None
@@ -908,6 +921,7 @@ def get_fwd_splitkv_blobs(
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "alibi"]
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_sink == "f"
if not cond:
continue
# PyTorch integration
@@ -917,6 +931,7 @@ def get_fwd_splitkv_blobs(
cond &= pipeline.F_bias in ["no", "bias"]
cond &= pipeline.F_squant == "f"
cond &= mode == "batch"
cond &= pipeline.F_sink == "f"
if not cond:
continue
# Aiter(mha_varlen_fwd) integration
@@ -1075,6 +1090,7 @@ def write_blobs(
lse=kernel.F_pipeline.F_lse,
squant=kernel.F_pipeline.F_squant,
pagedkv=kernel.F_pipeline.F_pagedkv,
sink=kernel.F_pipeline.F_sink,
spad=kernel.F_pipeline.F_spad,
skpad=kernel.F_pipeline.F_skpad,
dpad=kernel.F_pipeline.F_dpad,

View File

@@ -65,7 +65,8 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdPagedKVTraits<{F_spad},
{F_pagedkv}, //pagedkv
{F_squant},
{F_occupancy},
{F_skip}>;
{F_skip},
{F_sink}>;
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
@@ -100,7 +101,7 @@ using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdPagedKVKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}, {F_sink}>;
template<>
float fmha_fwd_pagedkv_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a)
@@ -129,9 +130,9 @@ float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, con
}}
"""
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && (t.has_sink == {F_sink}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip},{F_sink}>;
return fmha_fwd_pagedkv_<trait_, {F_arch.tag}>(s, a);
}}
"""
@@ -163,12 +164,13 @@ class FmhaFwdApiTrait:
dpad: str
dvpad: str
skip: str
sink: str
@property
def name(self) -> str:
return (
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}-{self.sink}"
)
@property
@@ -256,6 +258,7 @@ class FmhaFwdPipeline:
F_squant: str #
F_mask: str # value from MASK_MAP
F_skip: str # true/false
F_sink: str # true/false
@property
def name(self) -> str:
@@ -320,6 +323,10 @@ class FmhaFwdPipeline:
n += "_pagedkv"
else:
n += "_npagedkv"
if self.F_sink == "t":
n += "_sink"
else:
n += "_nsink"
return n
@@ -363,6 +370,7 @@ class FmhaFwdApiPool:
F_lse=BOOL_MAP[trait.lse],
F_pagedkv=BOOL_MAP[trait.pagedkv],
F_skip=BOOL_MAP[trait.skip],
F_sink=BOOL_MAP[trait.sink],
F_squant=BOOL_MAP[trait.squant],
F_scheck=trait.scheck,
F_skcheck=trait.skcheck,
@@ -480,6 +488,7 @@ class FmhaFwdKernel:
F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
F_skip=BOOL_MAP[self.F_pipeline.F_skip],
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
F_occupancy=self.F_tile.F_occupancy,
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
@@ -526,6 +535,7 @@ class FmhaFwdKernel:
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
skip=self.F_pipeline.F_skip,
sink=self.F_pipeline.F_sink,
)
@@ -539,22 +549,23 @@ class KernelComponentFactoryBase:
squant = "t" if dtype == "fp8" else "f"
pipelines = []
if dtype in ["fp16", "bf16"]:
for logits, mask, bias, pagedkv, skip in itertools.product(
for logits, mask, bias, pagedkv, skip, sink in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
BIAS_MAP.keys(),
["t"],
["f"],
["t", "f"],
):
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip
elif dtype in ["fp8", "bf8"]:
# no need lse/dropout kernels
for logits, mask, bias in itertools.product(
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
):
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip
elif dtype in ["fp8fp16", "fp8bf16"]:
pass # TODO
else:
@@ -678,6 +689,7 @@ def get_fwd_blobs(
cond &= pipeline.F_bias in ["no", "alibi"]
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_skip == "f"
cond &= pipeline.F_sink == "f"
if not cond:
continue
# PyTorch integration
@@ -687,6 +699,7 @@ def get_fwd_blobs(
cond &= pipeline.F_bias in ["no", "bias"]
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_skip == "f"
cond &= pipeline.F_sink == "f"
if not cond:
continue
# Aiter(mha_fwd) integration

View File

@@ -266,6 +266,7 @@ struct fmha_fwd_args
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t sink_size;
ck_tile::index_t mask_type;
ck_tile::index_t min_seqlen_q;
@@ -352,6 +353,7 @@ struct fmha_fwd_pagedkv_args
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t sink_size;
ck_tile::index_t mask_type;
ck_tile::index_t min_seqlen_q;
};
@@ -442,6 +444,7 @@ struct fmha_fwd_splitkv_args
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t sink_size;
ck_tile::index_t mask_type;
};
@@ -561,6 +564,7 @@ struct fmha_batch_prefill_args
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t sink_size;
ck_tile::index_t mask_type;
float p_drop;
@@ -613,6 +617,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_o,
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type,
args.min_seqlen_q,
args.p_drop,
@@ -663,6 +668,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.batch_stride_o,
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type,
args.p_drop,
args.s_randval,
@@ -824,6 +830,7 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
args.batch_stride_v,
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type,
args.min_seqlen_q);
}
@@ -869,6 +876,7 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
args.batch_stride_o,
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type);
}
}();
@@ -935,6 +943,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.split_stride_o_acc,
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type);
}
else
@@ -982,6 +991,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.split_stride_o_acc,
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type);
}
}();
@@ -1142,6 +1152,7 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.batch_stride_v,
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type,
args.p_drop,
args.s_randval,
@@ -1194,6 +1205,7 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.batch_stride_o,
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type,
args.p_drop,
args.s_randval,
@@ -1228,7 +1240,8 @@ template <ck_tile::index_t HDim_,
bool kPadD_,
bool kPadDv_,
bool kUseTrLoad_,
bool kSkipMinSeqlenQ_ = false>
bool kSkipMinSeqlenQ_ = false,
bool kHasSink_ = false>
struct fmha_fwd_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
@@ -1254,6 +1267,7 @@ struct fmha_fwd_traits_
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kUseTrLoad = kUseTrLoad_;
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
static constexpr bool kHasSink = kHasSink_;
};
template <typename Traits_, typename Arch = void>
@@ -1280,7 +1294,8 @@ template <ck_tile::index_t HDim_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_,
bool kSkipMinSeqlenQ_ = false>
bool kSkipMinSeqlenQ_ = false,
bool kHasSink_ = false>
struct fmha_fwd_pagedkv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
@@ -1305,6 +1320,7 @@ struct fmha_fwd_pagedkv_traits_
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
static constexpr bool kHasSink = kHasSink_;
};
template <typename Traits_, typename Arch = void>
@@ -1327,6 +1343,7 @@ template <ck_tile::index_t HDim_,
bool kStoreLse_,
bool kDoFp8StaticQuant_,
bool kIsPagedKV_,
bool kHasSink_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
@@ -1354,6 +1371,7 @@ struct fmha_fwd_splitkv_traits_
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kIsPagedKV = kIsPagedKV_;
static constexpr bool kHasSink = kHasSink_;
};
template <typename Traits_, typename Arch = void>
@@ -1440,6 +1458,7 @@ struct fmha_fwd_traits
bool has_dropout;
quant_scale_enum qscale_type;
bool skip_min_seqlen_q = false;
bool has_sink = false;
// TODO: padding check is inside this api
};
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
@@ -1458,6 +1477,7 @@ struct fmha_fwd_pagedkv_traits
bool use_pagedkv = true;
bool do_fp8_static_quant = false;
bool skip_min_seqlen_q = false;
bool has_sink = false;
// TODO: padding check is inside this api
};
@@ -1477,6 +1497,7 @@ struct fmha_fwd_splitkv_traits
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_lse;
bool do_fp8_static_quant;
bool has_sink = false;
// TODO: padding check is inside this api
};
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits,

View File

@@ -879,6 +879,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
traits.has_logits_soft_cap = 0.f < logits_soft_cap;
traits.mask_type = mask.type;
traits.bias_type = bias.type;
traits.has_sink = mask.sink > 0 ? true : false;
traits.has_lse = lse;
if constexpr(std::is_same_v<fmha_fwd_traits, std::decay_t<decltype(traits)>>)
@@ -1042,6 +1043,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
args.window_size_left = mask.left;
args.window_size_right = mask.right;
args.sink_size = mask.sink;
args.mask_type = static_cast<ck_tile::index_t>(mask.type);
if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
@@ -1645,7 +1647,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
ck_tile::reference_batched_masking<SaccDataType>(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left, mask.right, real_seqlen_q, real_seqlen_k));
mask.left, mask.right, mask.sink, real_seqlen_q, real_seqlen_k));
}
else
{
@@ -1657,6 +1659,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
mask.left,
mask.right,
mask.sink,
real_seqlen_q,
real_seqlen_k,
mask.type == mask_enum::mask_top_left));
@@ -1666,6 +1669,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left,
mask.right,
mask.sink,
real_seqlen_q,
real_seqlen_k,
mask.type == mask_enum::mask_top_left));

View File

@@ -25,6 +25,7 @@ struct mask_info
ck_tile::index_t seqlen_k;
ck_tile::index_t y, x;
ck_tile::index_t left, right; // FA style SWA left/right
ck_tile::index_t sink;
void serialize(std::ostream& os) const
{
@@ -58,13 +59,14 @@ struct mask_info
ck_tile::index_t window_size = std::stoi(v);
ck_tile::index_t left_size = -1;
ck_tile::index_t right_size = 0;
ck_tile::index_t sink_size = 0;
if(window_size > 0)
{
left_size = window_size / 2;
right_size = window_size - 1 - left_size;
}
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
left_size, right_size, y_total, x_total, t == "xt");
left_size, right_size, sink_size, y_total, x_total, t == "xt");
tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right;
tmp.y = r.at(ck_tile::number<0>{});
@@ -79,27 +81,54 @@ struct mask_info
{
throw std::invalid_argument("invalid mask value: " + str);
}
ck_tile::index_t v0 = std::stoi(v.substr(0, found_1));
ck_tile::index_t v1 = std::stoi(v.substr(found_1 + 1));
tmp.type = mask_enum::window_generic;
ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str());
auto found_2 = v.find(',', found_1 + 1);
ck_tile::index_t v1 = 0;
ck_tile::index_t sink = 0;
// ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str());
// TODO: some validation
if(t == "t")
{
if(found_2 != std::string::npos)
{
v1 = atoi(v.substr(found_1 + 1, found_2 - found_1 - 1).c_str());
sink = atoi(v.substr(found_2 + 1).c_str());
}
else
{
v1 = atoi(v.substr(found_1 + 1).c_str());
sink = 0;
}
tmp.type = mask_enum::mask_top_left;
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
v0, v1, y_total, x_total, true);
v0, v1, sink, y_total, x_total, true);
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = v0;
tmp.right = v1;
tmp.sink = sink;
}
else if(t == "b")
{
if(found_2 != std::string::npos)
{
v1 = atoi(v.substr(found_1 + 1, found_2 - found_1 - 1).c_str());
sink = atoi(v.substr(found_2 + 1).c_str());
}
else
{
v1 = atoi(v.substr(found_1 + 1).c_str());
sink = 0;
}
tmp.type = mask_enum::mask_bottom_right;
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
v0, v1, y_total, x_total, false);
v0, v1, sink, y_total, x_total, false);
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = v0;
tmp.right = v1;
tmp.sink = sink;
}
else if(t == "g")
{
@@ -108,6 +137,7 @@ struct mask_info
tmp.x = v1;
tmp.left = v0; // TODO: don't use this?
tmp.right = v1;
tmp.sink = 0;
}
}
else
@@ -126,6 +156,7 @@ struct mask_info
tmp.x = 1;
tmp.left = -1;
tmp.right = 0;
tmp.sink = 0;
}
else if(str == "2" || str == "b")
{
@@ -134,6 +165,7 @@ struct mask_info
tmp.x = seqlen_k - seqlen_q + 1;
tmp.left = -1;
tmp.right = 0;
tmp.sink = 0;
}
else
{

View File

@@ -0,0 +1,77 @@
#!/bin/bash
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# TODO: run this script from CK root or build directory
EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
KNAME=1
export CK_WARMUP=0
export CK_REPEAT=1
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
# mode=0
# export HIP_VISIBLE_DEVICES=4
TEST_SPLITKV=0
TEST_APPENDKV=0
# options:
# -s: run splitkv tests
# -a: run appendkv tests
while getopts ":sa" opt; do
case "${opt}" in
s)
TEST_SPLITKV=1
;;
a)
TEST_APPENDKV=1
;;
*)
;;
esac
done
run_fp16_bf16_tests() {
local NUM_SPLITS="1"
local PAGE_BLOCK_SIZE="0"
local CACHE_BATCH_IDX="0"
if [ $TEST_SPLITKV -eq 1 ] ; then
NUM_SPLITS="$NUM_SPLITS 2 3"
PAGE_BLOCK_SIZE="$PAGE_BLOCK_SIZE 128"
CACHE_BATCH_IDX="$CACHE_BATCH_IDX 1"
fi
for prec in "fp16"; do
for mode in 1 0 ; do
for perm in 0 1 ; do
for vlayout in "r" "c" ; do
for batch in 1 4; do
for head in 1; do
for h_k in 1; do
for q_seq in 128 512 ; do
for kv_seq in 128 1024; do
for hdim in 32 64 128 256; do #256
for lse in 0 1 ; do
for bias in "e" ; do
for p_drop in 0.0 0.2; do # 0.0
for mask in "t:2,0,4" "b:1,0,2"; do
for num_splits in $NUM_SPLITS ; do
for page_block_size in $PAGE_BLOCK_SIZE ; do
for cache_batch_idx in $CACHE_BATCH_IDX ; do
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=$batch -h=$head -h_k=$h_k -d=16 -d_v=$hdim -s=$q_seq -s_k=$kv_seq -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS -mask=$mask
done ; done ; done ; done ; done
done ; done ; done ; done ; done
done ; done ; done ; done ; done
done ; done
}
set -x
run_fp16_bf16_tests
set +x

View File

@@ -39,6 +39,7 @@ function print_log_header(){
#run verification tests
time example/ck_tile/01_fmha/script/smoke_test_fwd.sh
time example/ck_tile/01_fmha/script/smoke_test_bwd.sh
time example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh
#run performance benchmarks
export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log"

View File

@@ -0,0 +1,86 @@
#!/bin/bash
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# TODO: run this script from CK root or build directory
#EXE="/code/composable_kernel/build/bin/tile_example_fmha_fwd"
set -euo pipefail
SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd)
EXE_NAME=tile_example_fmha_fwd
EXE="$(find . -name $EXE_NAME -type f | head -n 1)"
KNAME=1
GPU_arch=$GPU_arch
if [ -z "$GPU_arch" ] ; then
GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}')
fi
set -x
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:2,0,2
# window_size[2,0], sink_size = 2
# x=1/y=3
# 1 * * * * * * * 1 * * * * * * *
# 1 1 * * * * * * 1 1 * * * * * *
# 1 1 1 * * * * * ----> 1 1 1 * * * * *
# * 1 1 1 * * * * 1 1 1 1 * * * *
# * * 1 1 1 * * * 1 1 1 1 1 * * *
# * * * 1 1 1 * * 1 1 * 1 1 1 * *
# * * * * 1 1 1 * 1 1 * * 1 1 1 *
# * * * * * 1 1 1 1 1 * * * 1 1 1
# l=2/r=0(tl) l=2/r=0/s=2(tl)
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:0,3,2 #-mask=b:3,0,2
# x=4/y=1
# 1 1 1 1 * * * * 1 1 1 1 * * * *
# * 1 1 1 1 * * * 1 1 1 1 1 * * *
# * * 1 1 1 1 * * ----> 1 1 1 1 1 1 * *
# * * * 1 1 1 1 * 1 1 * 1 1 1 1 *
# * * * * 1 1 1 1 1 1 * * 1 1 1 1
# l=0/r=3(tl) l=0/r=3/s=2(tl)
# l=3/r=0(br) l=3/r=0/s=2(br)
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:1,0,2
# x=4/y=-1
# * * 1 1 * * * * 1 1 1 1 * * * *
# * * * 1 1 * * * 1 1 * 1 1 * * *
# * * * * 1 1 * * ----> 1 1 * * 1 1 * *
# * * * * * 1 1 * 1 1 * * * 1 1 *
# * * * * * * 1 1 1 1 * * * * 1 1
# l=1/r=0(br) l=1/r=0/s=2(br)
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:2,0,2
# x=-1/y=5
# * * * * * * * * * * * *
# * * * * * * * * * * * *
# 1 * * * * * 1 * * * * *
# 1 1 * * * * 1 1 * * * *
# 1 1 1 * * * ----> 1 1 1 * * *
# * 1 1 1 * * 1 1 1 1 * *
# * * 1 1 1 * 1 1 1 1 1 *
# * * * 1 1 1 1 1 * 1 1 1
# l=2/r=0(br) l=2/r=0/s=2(br)
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:-1,1,2
# x=-1/y=8
# * * * * * * * * * *
# * * * * * * * * * *
# 1 * * * * ----> 1 * * * *
# 1 1 * * * 1 1 * * *
# 1 1 1 * * 1 1 1 * *
# 1 1 1 1 * 1 1 1 1 *
# 1 1 1 1 1 1 1 1 1 1
# 1 1 1 1 1 1 1 1 1 1
# l=2/r=0(br) l=2/r=0/s=2(br)