mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 12:11:19 +00:00
Add attn sink (#2892)
* enable attn sink Signed-off-by: JL-underdog <Jun.Lin@amd.com> * update attn_sink script Signed-off-by: JL-underdog <Jun.Lin@amd.com> * fix some error Signed-off-by: JL-underdog <Jun.Lin@amd.com> * clang-format Signed-off-by: JL-underdog <Jun.Lin@amd.com> * update fmha_bwd mask Signed-off-by: JL-underdog <Jun.Lin@amd.com> * update fmha_bwd_kernel'mask Signed-off-by: JL-underdog <Jun.Lin@amd.com> * update block_fmha_pipeline_qr_ks_vs.hpp Signed-off-by: JL-underdog <Jun.Lin@amd.com> * fix ci error Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * fix format error Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * Update block_fmha_bwd_pipeline_default_policy.hpp * Update fmha_fwd_runner.hpp * Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp * Update fmha_fwd_runner.hpp * Update fmha_fwd_runner.hpp * Update fmha_fwd_runner.hpp * update splitkv_pipline Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * update splitkv&pagedkv pipeline Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * add sink test Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * update attn_sink result log Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * update smoke_test_fwd_sink.sh Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * update test file Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * update test script Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * Update block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp * use constexpr kHasSink for sink in fmha pipeline Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * update by pre-commit Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update fmha_fwd.py * Update example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Remove causal mask setting logic from mask.hpp Removed the mask setting logic for causal masks. * fix ci error that some usage of lamada not support in c++17 Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * Update remod.py * add smoke sink test Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * Update fmha_pagedkv_prefill.py * Update FmhaFwdPipeline parameters in fmha_fwd.py * update block_fmha_pipeline_qr_ks_vs_async_trload.hpp Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * fix c++17 unsupprot error Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * Update block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp * Fix formatting of sink_seq_end assignment * Fix indentation for sink_seq_end assignment * Update block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp --------- Signed-off-by: JL-underdog <Jun.Lin@amd.com> Signed-off-by: LJ-underdog <Jun.Lin@amd.com> Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -66,7 +66,8 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_dropout},
|
||||
{F_squant},
|
||||
{F_occupancy},
|
||||
{F_skip}>;
|
||||
{F_skip},
|
||||
{F_sink}>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
|
||||
@@ -103,7 +104,7 @@ using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip},{F_sink}>;
|
||||
|
||||
template<>
|
||||
float fmha_fwd_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
@@ -190,9 +191,9 @@ FMHA_FWD_API_PER_HDIM_CASE = """{F_if}(t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hd
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
|
||||
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&(t.has_sink == {F_sink}) &&
|
||||
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink}>;
|
||||
return fmha_fwd_<trait_, {F_arch.tag}>(s, a);
|
||||
}}
|
||||
"""
|
||||
@@ -239,13 +240,14 @@ class FmhaFwdApiTrait:
|
||||
dvpad: str
|
||||
skip: str
|
||||
tr_load: str
|
||||
sink: str
|
||||
constraint: CppConstraint
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}-{self.sink}"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -345,6 +347,7 @@ class FmhaFwdPipeline:
|
||||
F_mask: str # value from MASK_MAP
|
||||
F_skip: str # true/false
|
||||
F_trload: str # true/false
|
||||
F_sink: str # true/false
|
||||
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@property
|
||||
@@ -415,6 +418,10 @@ class FmhaFwdPipeline:
|
||||
n += "_trload"
|
||||
else:
|
||||
n += "_ntrload"
|
||||
if self.F_sink == "t":
|
||||
n += "_sink"
|
||||
else:
|
||||
n += "_nsink"
|
||||
|
||||
return n
|
||||
|
||||
@@ -462,6 +469,7 @@ class FmhaFwdApiPool:
|
||||
F_dropout=BOOL_MAP[trait.dropout],
|
||||
F_skip=BOOL_MAP[trait.skip],
|
||||
F_trload=BOOL_MAP[trait.tr_load],
|
||||
F_sink=BOOL_MAP[trait.sink],
|
||||
F_squant=BOOL_MAP[trait.squant],
|
||||
F_scheck=trait.scheck,
|
||||
F_seqtune=trait.seqtune(max_bm0),
|
||||
@@ -588,6 +596,7 @@ class FmhaFwdKernel:
|
||||
F_mode=MODE_MAP[self.F_mode],
|
||||
F_pipeline=PIPELINE_MAP[self.F_pipeline.tag],
|
||||
F_trload=BOOL_MAP[self.F_pipeline.F_trload],
|
||||
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -630,6 +639,7 @@ class FmhaFwdKernel:
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
skip=self.F_pipeline.F_skip,
|
||||
tr_load=self.F_pipeline.F_trload,
|
||||
sink=self.F_pipeline.F_sink,
|
||||
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
|
||||
)
|
||||
|
||||
@@ -696,49 +706,51 @@ class KernelComponentFactoryGfx9:
|
||||
pipelines = []
|
||||
if dtype in ["fp32"]:
|
||||
squant = "f"
|
||||
for logits, mask, bias, lse, dropout, skip in itertools.product(
|
||||
for logits, mask, bias, lse, dropout, skip, sink in itertools.product(
|
||||
["t", "f"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
BIAS_MAP.keys(),
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip
|
||||
elif dtype in ["fp16", "bf16"]:
|
||||
squant = "f"
|
||||
for logits, mask, bias, lse, dropout, skip in itertools.product(
|
||||
for logits, mask, bias, lse, dropout, skip, sink in itertools.product(
|
||||
["t", "f"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
BIAS_MAP.keys(),
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
):
|
||||
if hdim == 256 and hdim_v == 256:
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip
|
||||
# the below two is used for hdim vectorize load
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip
|
||||
else:
|
||||
if bias == "bias":
|
||||
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip
|
||||
if receipt == 1 and bias != "bias":
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip
|
||||
elif dtype in ["fp8", "fp8bf16", "fp8fp32"]:
|
||||
# no need lse/dropout kernels
|
||||
for logits, squant, mask, bias in itertools.product(
|
||||
["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f", "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f", "f")) # fmt: skip
|
||||
elif dtype in ["fp8fp16", "bf8"]:
|
||||
# TODO
|
||||
None
|
||||
@@ -757,13 +769,14 @@ class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9):
|
||||
)
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
squant = "f"
|
||||
for logits, mask, bias, lse, dropout, skip in itertools.product(
|
||||
for logits, mask, bias, lse, dropout, skip, sink in itertools.product(
|
||||
["t", "f"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
BIAS_MAP.keys(),
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
):
|
||||
if (
|
||||
(hdim, hdim_v) in [(64, 64), (128, 128)]
|
||||
@@ -772,8 +785,8 @@ class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9):
|
||||
and dropout == "f"
|
||||
and skip == "f"
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "t", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "t", sink)) # fmt: skip
|
||||
return pipelines
|
||||
|
||||
|
||||
@@ -811,23 +824,24 @@ class KernelComponentFactoryGfx12:
|
||||
pipelines = []
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
squant = "f"
|
||||
for logits, mask, bias, lse, dropout, skip in itertools.product(
|
||||
for logits, mask, bias, lse, dropout, skip, sink in itertools.product(
|
||||
["t", "f"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
BIAS_MAP.keys(),
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip
|
||||
elif dtype in ["fp8", "fp8bf16", "fp8fp32"]:
|
||||
# no need lse/dropout kernels
|
||||
for logits, squant, mask, bias in itertools.product(
|
||||
["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f", "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f", "f")) # fmt: skip
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
@@ -934,6 +948,7 @@ def get_fwd_blobs(
|
||||
cond &= pipeline.F_bias in ["no", "alibi"]
|
||||
cond &= pipeline.F_squant == "f"
|
||||
cond &= pipeline.F_skip == "f"
|
||||
cond &= pipeline.F_sink == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
@@ -945,6 +960,7 @@ def get_fwd_blobs(
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_skip == "f"
|
||||
cond &= pipeline.F_logits == "f"
|
||||
cond &= pipeline.F_sink == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
@@ -985,6 +1001,7 @@ def get_fwd_blobs(
|
||||
cond = dtype == "fp32"
|
||||
cond &= pipeline.F_skip == "f"
|
||||
cond &= pipeline.F_logits == "f"
|
||||
cond &= pipeline.F_sink == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# fp32 only, minimal set of parameters
|
||||
@@ -998,6 +1015,7 @@ def get_fwd_blobs(
|
||||
cond &= pipeline.F_skip == "f"
|
||||
cond &= pipeline.F_logits == "f"
|
||||
cond &= pipeline.F_mask == "s_no"
|
||||
cond &= pipeline.F_sink == "f"
|
||||
if not cond:
|
||||
continue
|
||||
else:
|
||||
|
||||
@@ -74,7 +74,8 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
|
||||
{F_pagedkv},
|
||||
kHasUnevenSplits,
|
||||
kMergeNumHeadGroupsSeqLenQ,
|
||||
{F_occupancy}>;
|
||||
{F_occupancy},
|
||||
{F_sink}>;
|
||||
|
||||
using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
@@ -118,7 +119,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
|
||||
}} // anonymous namespace
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_sink}, {F_spad}, {F_skpad}, {F_dpad},
|
||||
{F_dvpad}>;
|
||||
|
||||
#pragma clang diagnostic push
|
||||
@@ -280,8 +281,8 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
|
||||
"""
|
||||
|
||||
FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
|
||||
// get combine kernel tile sizes
|
||||
using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
|
||||
@@ -333,6 +334,7 @@ class FmhaFwdSplitKVApiTrait:
|
||||
dpad: str
|
||||
dvpad: str
|
||||
pagedkv: str
|
||||
sink: str # sink or not
|
||||
bn1comb: int # tile size along v head_dim of combine kernel
|
||||
|
||||
@property
|
||||
@@ -340,7 +342,7 @@ class FmhaFwdSplitKVApiTrait:
|
||||
return (
|
||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-"
|
||||
+ f"{self.dvpad}-{self.pagedkv}"
|
||||
+ f"{self.dvpad}-{self.pagedkv}-{self.sink}"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -426,6 +428,7 @@ class FmhaFwdSplitKVPipeline:
|
||||
F_lse: str #
|
||||
F_squant: str #
|
||||
F_pagedkv: str # t/f
|
||||
F_sink: str # t/f
|
||||
F_mask: str # value from MASK_MAP
|
||||
|
||||
@property
|
||||
@@ -486,6 +489,10 @@ class FmhaFwdSplitKVPipeline:
|
||||
n += "_pagedkv"
|
||||
else:
|
||||
n += "_npagedkv"
|
||||
if self.F_sink == "t":
|
||||
n += "_sink"
|
||||
else:
|
||||
n += "_nsink"
|
||||
return n
|
||||
|
||||
|
||||
@@ -568,6 +575,7 @@ class FmhaFwdSplitKVApiPool:
|
||||
F_lse=BOOL_MAP[trait.lse],
|
||||
F_squant=BOOL_MAP[trait.squant],
|
||||
F_pagedkv=BOOL_MAP[trait.pagedkv],
|
||||
F_sink=BOOL_MAP[trait.sink],
|
||||
F_scheck=trait.scheck,
|
||||
F_skcheck=trait.skcheck,
|
||||
F_dcheck=trait.dcheck,
|
||||
@@ -668,6 +676,7 @@ class FmhaFwdSplitKVKernel:
|
||||
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
|
||||
F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
|
||||
F_occupancy=self.F_tile.F_occupancy,
|
||||
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode=MODE_MAP[self.F_mode],
|
||||
@@ -741,19 +750,23 @@ class KernelComponentFactoryBase:
|
||||
squant = "t" if dtype == "fp8" else "f"
|
||||
pipelines = []
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
for logits, mask, bias, pagedkv in itertools.product(
|
||||
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]
|
||||
for logits, mask, bias, pagedkv, sink in itertools.product(
|
||||
["t", "f"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
BIAS_MAP.keys(),
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
):
|
||||
pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
|
||||
pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
|
||||
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
|
||||
pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
|
||||
pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
|
||||
pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
|
||||
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
|
||||
pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
|
||||
elif dtype in ["fp8", "bf8"]:
|
||||
for logits, mask, bias in itertools.product(
|
||||
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
|
||||
):
|
||||
pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip
|
||||
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip
|
||||
pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip
|
||||
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip
|
||||
elif dtype in ["fp8fp16", "fp8bf16"]:
|
||||
# TODO
|
||||
None
|
||||
@@ -909,6 +922,7 @@ def get_fwd_splitkv_blobs(
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "alibi"]
|
||||
cond &= pipeline.F_squant == "f"
|
||||
cond &= pipeline.F_sink == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
@@ -918,6 +932,7 @@ def get_fwd_splitkv_blobs(
|
||||
cond &= pipeline.F_bias in ["no", "bias"]
|
||||
cond &= pipeline.F_squant == "f"
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_sink == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_varlen_fwd) integration
|
||||
@@ -1076,6 +1091,7 @@ def write_blobs(
|
||||
lse=kernel.F_pipeline.F_lse,
|
||||
squant=kernel.F_pipeline.F_squant,
|
||||
pagedkv=kernel.F_pipeline.F_pagedkv,
|
||||
sink=kernel.F_pipeline.F_sink,
|
||||
spad=kernel.F_pipeline.F_spad,
|
||||
skpad=kernel.F_pipeline.F_skpad,
|
||||
dpad=kernel.F_pipeline.F_dpad,
|
||||
|
||||
@@ -66,7 +66,8 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdPagedKVTraits<{F_spad},
|
||||
{F_pagedkv}, //pagedkv
|
||||
{F_squant},
|
||||
{F_occupancy},
|
||||
{F_skip}>;
|
||||
{F_skip},
|
||||
{F_sink}>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
|
||||
@@ -101,7 +102,7 @@ using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdPagedKVKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}, {F_sink}>;
|
||||
|
||||
template<>
|
||||
float fmha_fwd_pagedkv_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a)
|
||||
@@ -130,9 +131,9 @@ float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, con
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
|
||||
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && (t.has_sink == {F_sink}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip},{F_sink}>;
|
||||
return fmha_fwd_pagedkv_<trait_, {F_arch.tag}>(s, a);
|
||||
}}
|
||||
"""
|
||||
@@ -164,12 +165,13 @@ class FmhaFwdApiTrait:
|
||||
dpad: str
|
||||
dvpad: str
|
||||
skip: str
|
||||
sink: str
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}-{self.sink}"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -257,6 +259,7 @@ class FmhaFwdPipeline:
|
||||
F_squant: str #
|
||||
F_mask: str # value from MASK_MAP
|
||||
F_skip: str # true/false
|
||||
F_sink: str # true/false
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -321,6 +324,10 @@ class FmhaFwdPipeline:
|
||||
n += "_pagedkv"
|
||||
else:
|
||||
n += "_npagedkv"
|
||||
if self.F_sink == "t":
|
||||
n += "_sink"
|
||||
else:
|
||||
n += "_nsink"
|
||||
|
||||
return n
|
||||
|
||||
@@ -364,6 +371,7 @@ class FmhaFwdApiPool:
|
||||
F_lse=BOOL_MAP[trait.lse],
|
||||
F_pagedkv=BOOL_MAP[trait.pagedkv],
|
||||
F_skip=BOOL_MAP[trait.skip],
|
||||
F_sink=BOOL_MAP[trait.sink],
|
||||
F_squant=BOOL_MAP[trait.squant],
|
||||
F_scheck=trait.scheck,
|
||||
F_skcheck=trait.skcheck,
|
||||
@@ -481,6 +489,7 @@ class FmhaFwdKernel:
|
||||
F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
|
||||
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
|
||||
F_skip=BOOL_MAP[self.F_pipeline.F_skip],
|
||||
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
|
||||
F_occupancy=self.F_tile.F_occupancy,
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
@@ -527,6 +536,7 @@ class FmhaFwdKernel:
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
skip=self.F_pipeline.F_skip,
|
||||
sink=self.F_pipeline.F_sink,
|
||||
)
|
||||
|
||||
|
||||
@@ -540,22 +550,23 @@ class KernelComponentFactoryBase:
|
||||
squant = "t" if dtype == "fp8" else "f"
|
||||
pipelines = []
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
for logits, mask, bias, pagedkv, skip in itertools.product(
|
||||
for logits, mask, bias, pagedkv, skip, sink in itertools.product(
|
||||
["t", "f"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
BIAS_MAP.keys(),
|
||||
["t"],
|
||||
["f"],
|
||||
["t", "f"],
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip
|
||||
elif dtype in ["fp8", "bf8"]:
|
||||
# no need lse/dropout kernels
|
||||
for logits, mask, bias in itertools.product(
|
||||
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip
|
||||
elif dtype in ["fp8fp16", "fp8bf16"]:
|
||||
pass # TODO
|
||||
else:
|
||||
@@ -679,6 +690,7 @@ def get_fwd_blobs(
|
||||
cond &= pipeline.F_bias in ["no", "alibi"]
|
||||
cond &= pipeline.F_squant == "f"
|
||||
cond &= pipeline.F_skip == "f"
|
||||
cond &= pipeline.F_sink == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
@@ -688,6 +700,7 @@ def get_fwd_blobs(
|
||||
cond &= pipeline.F_bias in ["no", "bias"]
|
||||
cond &= pipeline.F_squant == "f"
|
||||
cond &= pipeline.F_skip == "f"
|
||||
cond &= pipeline.F_sink == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
|
||||
Reference in New Issue
Block a user