mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 06:44:36 +00:00
This reverts commit bbe1d3a917ee92655224c0f1528ace3a7b0e82a8.
[ROCm/composable_kernel commit: 5adaa201ed]
This commit is contained in:
@@ -62,7 +62,7 @@ set(FMHA_BWD_CODE_GEN_COMMON_ARGS
|
|||||||
# there is no corresponding instance for parameters).
|
# there is no corresponding instance for parameters).
|
||||||
if(BUILD_TESTING)
|
if(BUILD_TESTING)
|
||||||
# Filters are in the order of FMHA_FWD_KNOWN_APIS: fwd,fwd_splitkv_combine@fwd_splitkv,fwd_appendkv,pagedkv_prefill
|
# Filters are in the order of FMHA_FWD_KNOWN_APIS: fwd,fwd_splitkv_combine@fwd_splitkv,fwd_appendkv,pagedkv_prefill
|
||||||
list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*_nsink*,*@*_nlogits*_nbias*_nsink*,*,*_nlogits*_nskip*_pagedkv*)
|
list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*,*@*_nlogits*_nbias*,*,*_nlogits*_nskip*_pagedkv)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# generate a list of kernels, but not actually emit files at config sta
|
# generate a list of kernels, but not actually emit files at config sta
|
||||||
|
|||||||
@@ -66,8 +66,7 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
|||||||
{F_dropout},
|
{F_dropout},
|
||||||
{F_squant},
|
{F_squant},
|
||||||
{F_occupancy},
|
{F_occupancy},
|
||||||
{F_skip},
|
{F_skip}>;
|
||||||
{F_sink}>;
|
|
||||||
|
|
||||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||||
|
|
||||||
@@ -104,7 +103,7 @@ using fmha_kernel_{F_idx} =
|
|||||||
ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||||
|
|
||||||
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip},{F_sink}>;
|
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
|
||||||
|
|
||||||
template<>
|
template<>
|
||||||
float fmha_fwd_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
float fmha_fwd_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||||
@@ -191,9 +190,9 @@ FMHA_FWD_API_PER_HDIM_CASE = """{F_if}(t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hd
|
|||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&(t.has_sink == {F_sink}) &&
|
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
|
||||||
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink}>;
|
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
|
||||||
return fmha_fwd_<trait_, {F_arch.tag}>(s, a);
|
return fmha_fwd_<trait_, {F_arch.tag}>(s, a);
|
||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
@@ -240,14 +239,13 @@ class FmhaFwdApiTrait:
|
|||||||
dvpad: str
|
dvpad: str
|
||||||
skip: str
|
skip: str
|
||||||
tr_load: str
|
tr_load: str
|
||||||
sink: str
|
|
||||||
constraint: CppConstraint
|
constraint: CppConstraint
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return (
|
return (
|
||||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
||||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}-{self.sink}"
|
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -347,7 +345,6 @@ class FmhaFwdPipeline:
|
|||||||
F_mask: str # value from MASK_MAP
|
F_mask: str # value from MASK_MAP
|
||||||
F_skip: str # true/false
|
F_skip: str # true/false
|
||||||
F_trload: str # true/false
|
F_trload: str # true/false
|
||||||
F_sink: str # true/false
|
|
||||||
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
|
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -418,10 +415,6 @@ class FmhaFwdPipeline:
|
|||||||
n += "_trload"
|
n += "_trload"
|
||||||
else:
|
else:
|
||||||
n += "_ntrload"
|
n += "_ntrload"
|
||||||
if self.F_sink == "t":
|
|
||||||
n += "_sink"
|
|
||||||
else:
|
|
||||||
n += "_nsink"
|
|
||||||
|
|
||||||
return n
|
return n
|
||||||
|
|
||||||
@@ -469,7 +462,6 @@ class FmhaFwdApiPool:
|
|||||||
F_dropout=BOOL_MAP[trait.dropout],
|
F_dropout=BOOL_MAP[trait.dropout],
|
||||||
F_skip=BOOL_MAP[trait.skip],
|
F_skip=BOOL_MAP[trait.skip],
|
||||||
F_trload=BOOL_MAP[trait.tr_load],
|
F_trload=BOOL_MAP[trait.tr_load],
|
||||||
F_sink=BOOL_MAP[trait.sink],
|
|
||||||
F_squant=BOOL_MAP[trait.squant],
|
F_squant=BOOL_MAP[trait.squant],
|
||||||
F_scheck=trait.scheck,
|
F_scheck=trait.scheck,
|
||||||
F_seqtune=trait.seqtune(max_bm0),
|
F_seqtune=trait.seqtune(max_bm0),
|
||||||
@@ -596,7 +588,6 @@ class FmhaFwdKernel:
|
|||||||
F_mode=MODE_MAP[self.F_mode],
|
F_mode=MODE_MAP[self.F_mode],
|
||||||
F_pipeline=PIPELINE_MAP[self.F_pipeline.tag],
|
F_pipeline=PIPELINE_MAP[self.F_pipeline.tag],
|
||||||
F_trload=BOOL_MAP[self.F_pipeline.F_trload],
|
F_trload=BOOL_MAP[self.F_pipeline.F_trload],
|
||||||
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -639,7 +630,6 @@ class FmhaFwdKernel:
|
|||||||
dvpad=self.F_pipeline.F_dvpad,
|
dvpad=self.F_pipeline.F_dvpad,
|
||||||
skip=self.F_pipeline.F_skip,
|
skip=self.F_pipeline.F_skip,
|
||||||
tr_load=self.F_pipeline.F_trload,
|
tr_load=self.F_pipeline.F_trload,
|
||||||
sink=self.F_pipeline.F_sink,
|
|
||||||
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
|
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -706,51 +696,49 @@ class KernelComponentFactoryGfx9:
|
|||||||
pipelines = []
|
pipelines = []
|
||||||
if dtype in ["fp32"]:
|
if dtype in ["fp32"]:
|
||||||
squant = "f"
|
squant = "f"
|
||||||
for logits, mask, bias, lse, dropout, skip, sink in itertools.product(
|
for logits, mask, bias, lse, dropout, skip in itertools.product(
|
||||||
["t", "f"],
|
["t", "f"],
|
||||||
get_mask_map(mask_impl).keys(),
|
get_mask_map(mask_impl).keys(),
|
||||||
BIAS_MAP.keys(),
|
BIAS_MAP.keys(),
|
||||||
["t", "f"],
|
["t", "f"],
|
||||||
["t", "f"],
|
["t", "f"],
|
||||||
["t", "f"],
|
["t", "f"],
|
||||||
["t", "f"],
|
|
||||||
):
|
):
|
||||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip
|
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip
|
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip
|
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||||
elif dtype in ["fp16", "bf16"]:
|
elif dtype in ["fp16", "bf16"]:
|
||||||
squant = "f"
|
squant = "f"
|
||||||
for logits, mask, bias, lse, dropout, skip, sink in itertools.product(
|
for logits, mask, bias, lse, dropout, skip in itertools.product(
|
||||||
["t", "f"],
|
["t", "f"],
|
||||||
get_mask_map(mask_impl).keys(),
|
get_mask_map(mask_impl).keys(),
|
||||||
BIAS_MAP.keys(),
|
BIAS_MAP.keys(),
|
||||||
["t", "f"],
|
["t", "f"],
|
||||||
["t", "f"],
|
["t", "f"],
|
||||||
["t", "f"],
|
["t", "f"],
|
||||||
["t", "f"],
|
|
||||||
):
|
):
|
||||||
if hdim == 256 and hdim_v == 256:
|
if hdim == 256 and hdim_v == 256:
|
||||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip
|
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||||
# the below two is used for hdim vectorize load
|
# the below two is used for hdim vectorize load
|
||||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip
|
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip
|
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||||
else:
|
else:
|
||||||
if bias == "bias":
|
if bias == "bias":
|
||||||
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
|
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
|
||||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip
|
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip
|
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||||
else:
|
else:
|
||||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip
|
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip
|
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||||
if receipt == 1 and bias != "bias":
|
if receipt == 1 and bias != "bias":
|
||||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip
|
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip
|
||||||
elif dtype in ["fp8", "fp8bf16", "fp8fp32"]:
|
elif dtype in ["fp8", "fp8bf16", "fp8fp32"]:
|
||||||
# no need lse/dropout kernels
|
# no need lse/dropout kernels
|
||||||
for logits, squant, mask, bias in itertools.product(
|
for logits, squant, mask, bias in itertools.product(
|
||||||
["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
|
["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
|
||||||
):
|
):
|
||||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f", "f")) # fmt: skip
|
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip
|
||||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f", "f")) # fmt: skip
|
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip
|
||||||
elif dtype in ["fp8fp16", "bf8"]:
|
elif dtype in ["fp8fp16", "bf8"]:
|
||||||
# TODO
|
# TODO
|
||||||
None
|
None
|
||||||
@@ -769,14 +757,13 @@ class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9):
|
|||||||
)
|
)
|
||||||
if dtype in ["fp16", "bf16"]:
|
if dtype in ["fp16", "bf16"]:
|
||||||
squant = "f"
|
squant = "f"
|
||||||
for logits, mask, bias, lse, dropout, skip, sink in itertools.product(
|
for logits, mask, bias, lse, dropout, skip in itertools.product(
|
||||||
["t", "f"],
|
["t", "f"],
|
||||||
get_mask_map(mask_impl).keys(),
|
get_mask_map(mask_impl).keys(),
|
||||||
BIAS_MAP.keys(),
|
BIAS_MAP.keys(),
|
||||||
["t", "f"],
|
["t", "f"],
|
||||||
["t", "f"],
|
["t", "f"],
|
||||||
["t", "f"],
|
["t", "f"],
|
||||||
["t", "f"],
|
|
||||||
):
|
):
|
||||||
if (
|
if (
|
||||||
(hdim, hdim_v) in [(64, 64), (128, 128)]
|
(hdim, hdim_v) in [(64, 64), (128, 128)]
|
||||||
@@ -785,8 +772,8 @@ class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9):
|
|||||||
and dropout == "f"
|
and dropout == "f"
|
||||||
and skip == "f"
|
and skip == "f"
|
||||||
):
|
):
|
||||||
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "t", sink)) # fmt: skip
|
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip
|
||||||
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "t", sink)) # fmt: skip
|
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip
|
||||||
return pipelines
|
return pipelines
|
||||||
|
|
||||||
|
|
||||||
@@ -824,24 +811,23 @@ class KernelComponentFactoryGfx12:
|
|||||||
pipelines = []
|
pipelines = []
|
||||||
if dtype in ["fp16", "bf16"]:
|
if dtype in ["fp16", "bf16"]:
|
||||||
squant = "f"
|
squant = "f"
|
||||||
for logits, mask, bias, lse, dropout, skip, sink in itertools.product(
|
for logits, mask, bias, lse, dropout, skip in itertools.product(
|
||||||
["t", "f"],
|
["t", "f"],
|
||||||
get_mask_map(mask_impl).keys(),
|
get_mask_map(mask_impl).keys(),
|
||||||
BIAS_MAP.keys(),
|
BIAS_MAP.keys(),
|
||||||
["t", "f"],
|
["t", "f"],
|
||||||
["t", "f"],
|
["t", "f"],
|
||||||
["t", "f"],
|
["t", "f"],
|
||||||
["t", "f"],
|
|
||||||
):
|
):
|
||||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip
|
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip
|
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||||
elif dtype in ["fp8", "fp8bf16", "fp8fp32"]:
|
elif dtype in ["fp8", "fp8bf16", "fp8fp32"]:
|
||||||
# no need lse/dropout kernels
|
# no need lse/dropout kernels
|
||||||
for logits, squant, mask, bias in itertools.product(
|
for logits, squant, mask, bias in itertools.product(
|
||||||
["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
|
["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
|
||||||
):
|
):
|
||||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f", "f")) # fmt: skip
|
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip
|
||||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f", "f")) # fmt: skip
|
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip
|
||||||
else:
|
else:
|
||||||
assert False
|
assert False
|
||||||
return pipelines
|
return pipelines
|
||||||
@@ -948,7 +934,6 @@ def get_fwd_blobs(
|
|||||||
cond &= pipeline.F_bias in ["no", "alibi"]
|
cond &= pipeline.F_bias in ["no", "alibi"]
|
||||||
cond &= pipeline.F_squant == "f"
|
cond &= pipeline.F_squant == "f"
|
||||||
cond &= pipeline.F_skip == "f"
|
cond &= pipeline.F_skip == "f"
|
||||||
cond &= pipeline.F_sink == "f"
|
|
||||||
if not cond:
|
if not cond:
|
||||||
continue
|
continue
|
||||||
# PyTorch integration
|
# PyTorch integration
|
||||||
@@ -960,7 +945,6 @@ def get_fwd_blobs(
|
|||||||
cond &= mode == "batch"
|
cond &= mode == "batch"
|
||||||
cond &= pipeline.F_skip == "f"
|
cond &= pipeline.F_skip == "f"
|
||||||
cond &= pipeline.F_logits == "f"
|
cond &= pipeline.F_logits == "f"
|
||||||
cond &= pipeline.F_sink == "f"
|
|
||||||
if not cond:
|
if not cond:
|
||||||
continue
|
continue
|
||||||
# Aiter(mha_fwd) integration
|
# Aiter(mha_fwd) integration
|
||||||
@@ -1001,7 +985,6 @@ def get_fwd_blobs(
|
|||||||
cond = dtype == "fp32"
|
cond = dtype == "fp32"
|
||||||
cond &= pipeline.F_skip == "f"
|
cond &= pipeline.F_skip == "f"
|
||||||
cond &= pipeline.F_logits == "f"
|
cond &= pipeline.F_logits == "f"
|
||||||
cond &= pipeline.F_sink == "f"
|
|
||||||
if not cond:
|
if not cond:
|
||||||
continue
|
continue
|
||||||
# fp32 only, minimal set of parameters
|
# fp32 only, minimal set of parameters
|
||||||
@@ -1015,7 +998,6 @@ def get_fwd_blobs(
|
|||||||
cond &= pipeline.F_skip == "f"
|
cond &= pipeline.F_skip == "f"
|
||||||
cond &= pipeline.F_logits == "f"
|
cond &= pipeline.F_logits == "f"
|
||||||
cond &= pipeline.F_mask == "s_no"
|
cond &= pipeline.F_mask == "s_no"
|
||||||
cond &= pipeline.F_sink == "f"
|
|
||||||
if not cond:
|
if not cond:
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -74,8 +74,7 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
|
|||||||
{F_pagedkv},
|
{F_pagedkv},
|
||||||
kHasUnevenSplits,
|
kHasUnevenSplits,
|
||||||
kMergeNumHeadGroupsSeqLenQ,
|
kMergeNumHeadGroupsSeqLenQ,
|
||||||
{F_occupancy},
|
{F_occupancy}>;
|
||||||
{F_sink}>;
|
|
||||||
|
|
||||||
using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
|
using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
|
||||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||||
@@ -119,7 +118,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
|
|||||||
}} // anonymous namespace
|
}} // anonymous namespace
|
||||||
|
|
||||||
using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_sink}, {F_spad}, {F_skpad}, {F_dpad},
|
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
|
||||||
{F_dvpad}>;
|
{F_dvpad}>;
|
||||||
|
|
||||||
#pragma clang diagnostic push
|
#pragma clang diagnostic push
|
||||||
@@ -281,8 +280,8 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
|
FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||||
((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||||
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||||
|
|
||||||
// get combine kernel tile sizes
|
// get combine kernel tile sizes
|
||||||
using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
|
using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
|
||||||
@@ -334,7 +333,6 @@ class FmhaFwdSplitKVApiTrait:
|
|||||||
dpad: str
|
dpad: str
|
||||||
dvpad: str
|
dvpad: str
|
||||||
pagedkv: str
|
pagedkv: str
|
||||||
sink: str # sink or not
|
|
||||||
bn1comb: int # tile size along v head_dim of combine kernel
|
bn1comb: int # tile size along v head_dim of combine kernel
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -342,7 +340,7 @@ class FmhaFwdSplitKVApiTrait:
|
|||||||
return (
|
return (
|
||||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
||||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-"
|
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-"
|
||||||
+ f"{self.dvpad}-{self.pagedkv}-{self.sink}"
|
+ f"{self.dvpad}-{self.pagedkv}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -428,7 +426,6 @@ class FmhaFwdSplitKVPipeline:
|
|||||||
F_lse: str #
|
F_lse: str #
|
||||||
F_squant: str #
|
F_squant: str #
|
||||||
F_pagedkv: str # t/f
|
F_pagedkv: str # t/f
|
||||||
F_sink: str # t/f
|
|
||||||
F_mask: str # value from MASK_MAP
|
F_mask: str # value from MASK_MAP
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -489,10 +486,6 @@ class FmhaFwdSplitKVPipeline:
|
|||||||
n += "_pagedkv"
|
n += "_pagedkv"
|
||||||
else:
|
else:
|
||||||
n += "_npagedkv"
|
n += "_npagedkv"
|
||||||
if self.F_sink == "t":
|
|
||||||
n += "_sink"
|
|
||||||
else:
|
|
||||||
n += "_nsink"
|
|
||||||
return n
|
return n
|
||||||
|
|
||||||
|
|
||||||
@@ -575,7 +568,6 @@ class FmhaFwdSplitKVApiPool:
|
|||||||
F_lse=BOOL_MAP[trait.lse],
|
F_lse=BOOL_MAP[trait.lse],
|
||||||
F_squant=BOOL_MAP[trait.squant],
|
F_squant=BOOL_MAP[trait.squant],
|
||||||
F_pagedkv=BOOL_MAP[trait.pagedkv],
|
F_pagedkv=BOOL_MAP[trait.pagedkv],
|
||||||
F_sink=BOOL_MAP[trait.sink],
|
|
||||||
F_scheck=trait.scheck,
|
F_scheck=trait.scheck,
|
||||||
F_skcheck=trait.skcheck,
|
F_skcheck=trait.skcheck,
|
||||||
F_dcheck=trait.dcheck,
|
F_dcheck=trait.dcheck,
|
||||||
@@ -676,7 +668,6 @@ class FmhaFwdSplitKVKernel:
|
|||||||
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
|
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
|
||||||
F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
|
F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
|
||||||
F_occupancy=self.F_tile.F_occupancy,
|
F_occupancy=self.F_tile.F_occupancy,
|
||||||
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
|
|
||||||
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||||
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||||
F_mode=MODE_MAP[self.F_mode],
|
F_mode=MODE_MAP[self.F_mode],
|
||||||
@@ -750,23 +741,19 @@ class KernelComponentFactoryBase:
|
|||||||
squant = "t" if dtype == "fp8" else "f"
|
squant = "t" if dtype == "fp8" else "f"
|
||||||
pipelines = []
|
pipelines = []
|
||||||
if dtype in ["fp16", "bf16"]:
|
if dtype in ["fp16", "bf16"]:
|
||||||
for logits, mask, bias, pagedkv, sink in itertools.product(
|
for logits, mask, bias, pagedkv in itertools.product(
|
||||||
["t", "f"],
|
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]
|
||||||
get_mask_map(mask_impl).keys(),
|
|
||||||
BIAS_MAP.keys(),
|
|
||||||
["t", "f"],
|
|
||||||
["t", "f"],
|
|
||||||
):
|
):
|
||||||
pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
|
pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
|
||||||
pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
|
pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
|
||||||
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
|
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
|
||||||
pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
|
pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
|
||||||
elif dtype in ["fp8", "bf8"]:
|
elif dtype in ["fp8", "bf8"]:
|
||||||
for logits, mask, bias in itertools.product(
|
for logits, mask, bias in itertools.product(
|
||||||
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
|
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
|
||||||
):
|
):
|
||||||
pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip
|
pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip
|
||||||
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip
|
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip
|
||||||
elif dtype in ["fp8fp16", "fp8bf16"]:
|
elif dtype in ["fp8fp16", "fp8bf16"]:
|
||||||
# TODO
|
# TODO
|
||||||
None
|
None
|
||||||
@@ -922,7 +909,6 @@ def get_fwd_splitkv_blobs(
|
|||||||
cond &= pipeline.F_vlayout == "row"
|
cond &= pipeline.F_vlayout == "row"
|
||||||
cond &= pipeline.F_bias in ["no", "alibi"]
|
cond &= pipeline.F_bias in ["no", "alibi"]
|
||||||
cond &= pipeline.F_squant == "f"
|
cond &= pipeline.F_squant == "f"
|
||||||
cond &= pipeline.F_sink == "f"
|
|
||||||
if not cond:
|
if not cond:
|
||||||
continue
|
continue
|
||||||
# PyTorch integration
|
# PyTorch integration
|
||||||
@@ -932,7 +918,6 @@ def get_fwd_splitkv_blobs(
|
|||||||
cond &= pipeline.F_bias in ["no", "bias"]
|
cond &= pipeline.F_bias in ["no", "bias"]
|
||||||
cond &= pipeline.F_squant == "f"
|
cond &= pipeline.F_squant == "f"
|
||||||
cond &= mode == "batch"
|
cond &= mode == "batch"
|
||||||
cond &= pipeline.F_sink == "f"
|
|
||||||
if not cond:
|
if not cond:
|
||||||
continue
|
continue
|
||||||
# Aiter(mha_varlen_fwd) integration
|
# Aiter(mha_varlen_fwd) integration
|
||||||
@@ -1091,7 +1076,6 @@ def write_blobs(
|
|||||||
lse=kernel.F_pipeline.F_lse,
|
lse=kernel.F_pipeline.F_lse,
|
||||||
squant=kernel.F_pipeline.F_squant,
|
squant=kernel.F_pipeline.F_squant,
|
||||||
pagedkv=kernel.F_pipeline.F_pagedkv,
|
pagedkv=kernel.F_pipeline.F_pagedkv,
|
||||||
sink=kernel.F_pipeline.F_sink,
|
|
||||||
spad=kernel.F_pipeline.F_spad,
|
spad=kernel.F_pipeline.F_spad,
|
||||||
skpad=kernel.F_pipeline.F_skpad,
|
skpad=kernel.F_pipeline.F_skpad,
|
||||||
dpad=kernel.F_pipeline.F_dpad,
|
dpad=kernel.F_pipeline.F_dpad,
|
||||||
|
|||||||
@@ -66,8 +66,7 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdPagedKVTraits<{F_spad},
|
|||||||
{F_pagedkv}, //pagedkv
|
{F_pagedkv}, //pagedkv
|
||||||
{F_squant},
|
{F_squant},
|
||||||
{F_occupancy},
|
{F_occupancy},
|
||||||
{F_skip},
|
{F_skip}>;
|
||||||
{F_sink}>;
|
|
||||||
|
|
||||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||||
|
|
||||||
@@ -102,7 +101,7 @@ using fmha_kernel_{F_idx} =
|
|||||||
ck_tile::FmhaFwdPagedKVKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
ck_tile::FmhaFwdPagedKVKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||||
|
|
||||||
using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}, {F_sink}>;
|
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||||
|
|
||||||
template<>
|
template<>
|
||||||
float fmha_fwd_pagedkv_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a)
|
float fmha_fwd_pagedkv_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a)
|
||||||
@@ -131,9 +130,9 @@ float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, con
|
|||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && (t.has_sink == {F_sink}) &&
|
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
|
||||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||||
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip},{F_sink}>;
|
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||||
return fmha_fwd_pagedkv_<trait_, {F_arch.tag}>(s, a);
|
return fmha_fwd_pagedkv_<trait_, {F_arch.tag}>(s, a);
|
||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
@@ -165,13 +164,12 @@ class FmhaFwdApiTrait:
|
|||||||
dpad: str
|
dpad: str
|
||||||
dvpad: str
|
dvpad: str
|
||||||
skip: str
|
skip: str
|
||||||
sink: str
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return (
|
return (
|
||||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
||||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}-{self.sink}"
|
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -259,7 +257,6 @@ class FmhaFwdPipeline:
|
|||||||
F_squant: str #
|
F_squant: str #
|
||||||
F_mask: str # value from MASK_MAP
|
F_mask: str # value from MASK_MAP
|
||||||
F_skip: str # true/false
|
F_skip: str # true/false
|
||||||
F_sink: str # true/false
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
@@ -324,10 +321,6 @@ class FmhaFwdPipeline:
|
|||||||
n += "_pagedkv"
|
n += "_pagedkv"
|
||||||
else:
|
else:
|
||||||
n += "_npagedkv"
|
n += "_npagedkv"
|
||||||
if self.F_sink == "t":
|
|
||||||
n += "_sink"
|
|
||||||
else:
|
|
||||||
n += "_nsink"
|
|
||||||
|
|
||||||
return n
|
return n
|
||||||
|
|
||||||
@@ -371,7 +364,6 @@ class FmhaFwdApiPool:
|
|||||||
F_lse=BOOL_MAP[trait.lse],
|
F_lse=BOOL_MAP[trait.lse],
|
||||||
F_pagedkv=BOOL_MAP[trait.pagedkv],
|
F_pagedkv=BOOL_MAP[trait.pagedkv],
|
||||||
F_skip=BOOL_MAP[trait.skip],
|
F_skip=BOOL_MAP[trait.skip],
|
||||||
F_sink=BOOL_MAP[trait.sink],
|
|
||||||
F_squant=BOOL_MAP[trait.squant],
|
F_squant=BOOL_MAP[trait.squant],
|
||||||
F_scheck=trait.scheck,
|
F_scheck=trait.scheck,
|
||||||
F_skcheck=trait.skcheck,
|
F_skcheck=trait.skcheck,
|
||||||
@@ -489,7 +481,6 @@ class FmhaFwdKernel:
|
|||||||
F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
|
F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
|
||||||
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
|
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
|
||||||
F_skip=BOOL_MAP[self.F_pipeline.F_skip],
|
F_skip=BOOL_MAP[self.F_pipeline.F_skip],
|
||||||
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
|
|
||||||
F_occupancy=self.F_tile.F_occupancy,
|
F_occupancy=self.F_tile.F_occupancy,
|
||||||
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||||
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||||
@@ -536,7 +527,6 @@ class FmhaFwdKernel:
|
|||||||
dpad=self.F_pipeline.F_dpad,
|
dpad=self.F_pipeline.F_dpad,
|
||||||
dvpad=self.F_pipeline.F_dvpad,
|
dvpad=self.F_pipeline.F_dvpad,
|
||||||
skip=self.F_pipeline.F_skip,
|
skip=self.F_pipeline.F_skip,
|
||||||
sink=self.F_pipeline.F_sink,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -550,23 +540,22 @@ class KernelComponentFactoryBase:
|
|||||||
squant = "t" if dtype == "fp8" else "f"
|
squant = "t" if dtype == "fp8" else "f"
|
||||||
pipelines = []
|
pipelines = []
|
||||||
if dtype in ["fp16", "bf16"]:
|
if dtype in ["fp16", "bf16"]:
|
||||||
for logits, mask, bias, pagedkv, skip, sink in itertools.product(
|
for logits, mask, bias, pagedkv, skip in itertools.product(
|
||||||
["t", "f"],
|
["t", "f"],
|
||||||
get_mask_map(mask_impl).keys(),
|
get_mask_map(mask_impl).keys(),
|
||||||
BIAS_MAP.keys(),
|
BIAS_MAP.keys(),
|
||||||
["t"],
|
["t"],
|
||||||
["f"],
|
["f"],
|
||||||
["t", "f"],
|
|
||||||
):
|
):
|
||||||
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip
|
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip
|
||||||
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip
|
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip
|
||||||
elif dtype in ["fp8", "bf8"]:
|
elif dtype in ["fp8", "bf8"]:
|
||||||
# no need lse/dropout kernels
|
# no need lse/dropout kernels
|
||||||
for logits, mask, bias in itertools.product(
|
for logits, mask, bias in itertools.product(
|
||||||
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
|
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
|
||||||
):
|
):
|
||||||
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip
|
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip
|
||||||
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip
|
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip
|
||||||
elif dtype in ["fp8fp16", "fp8bf16"]:
|
elif dtype in ["fp8fp16", "fp8bf16"]:
|
||||||
pass # TODO
|
pass # TODO
|
||||||
else:
|
else:
|
||||||
@@ -690,7 +679,6 @@ def get_fwd_blobs(
|
|||||||
cond &= pipeline.F_bias in ["no", "alibi"]
|
cond &= pipeline.F_bias in ["no", "alibi"]
|
||||||
cond &= pipeline.F_squant == "f"
|
cond &= pipeline.F_squant == "f"
|
||||||
cond &= pipeline.F_skip == "f"
|
cond &= pipeline.F_skip == "f"
|
||||||
cond &= pipeline.F_sink == "f"
|
|
||||||
if not cond:
|
if not cond:
|
||||||
continue
|
continue
|
||||||
# PyTorch integration
|
# PyTorch integration
|
||||||
@@ -700,7 +688,6 @@ def get_fwd_blobs(
|
|||||||
cond &= pipeline.F_bias in ["no", "bias"]
|
cond &= pipeline.F_bias in ["no", "bias"]
|
||||||
cond &= pipeline.F_squant == "f"
|
cond &= pipeline.F_squant == "f"
|
||||||
cond &= pipeline.F_skip == "f"
|
cond &= pipeline.F_skip == "f"
|
||||||
cond &= pipeline.F_sink == "f"
|
|
||||||
if not cond:
|
if not cond:
|
||||||
continue
|
continue
|
||||||
# Aiter(mha_fwd) integration
|
# Aiter(mha_fwd) integration
|
||||||
|
|||||||
@@ -265,7 +265,6 @@ struct fmha_fwd_args
|
|||||||
|
|
||||||
ck_tile::index_t window_size_left;
|
ck_tile::index_t window_size_left;
|
||||||
ck_tile::index_t window_size_right;
|
ck_tile::index_t window_size_right;
|
||||||
ck_tile::index_t sink_size;
|
|
||||||
ck_tile::index_t mask_type;
|
ck_tile::index_t mask_type;
|
||||||
ck_tile::index_t min_seqlen_q;
|
ck_tile::index_t min_seqlen_q;
|
||||||
|
|
||||||
@@ -352,7 +351,6 @@ struct fmha_fwd_pagedkv_args
|
|||||||
|
|
||||||
ck_tile::index_t window_size_left;
|
ck_tile::index_t window_size_left;
|
||||||
ck_tile::index_t window_size_right;
|
ck_tile::index_t window_size_right;
|
||||||
ck_tile::index_t sink_size;
|
|
||||||
ck_tile::index_t mask_type;
|
ck_tile::index_t mask_type;
|
||||||
ck_tile::index_t min_seqlen_q;
|
ck_tile::index_t min_seqlen_q;
|
||||||
};
|
};
|
||||||
@@ -443,7 +441,6 @@ struct fmha_fwd_splitkv_args
|
|||||||
|
|
||||||
ck_tile::index_t window_size_left;
|
ck_tile::index_t window_size_left;
|
||||||
ck_tile::index_t window_size_right;
|
ck_tile::index_t window_size_right;
|
||||||
ck_tile::index_t sink_size;
|
|
||||||
ck_tile::index_t mask_type;
|
ck_tile::index_t mask_type;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -614,7 +611,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
|||||||
args.nhead_stride_o,
|
args.nhead_stride_o,
|
||||||
args.window_size_left,
|
args.window_size_left,
|
||||||
args.window_size_right,
|
args.window_size_right,
|
||||||
args.sink_size,
|
|
||||||
args.mask_type,
|
args.mask_type,
|
||||||
args.min_seqlen_q,
|
args.min_seqlen_q,
|
||||||
args.p_drop,
|
args.p_drop,
|
||||||
@@ -664,7 +660,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
|||||||
args.batch_stride_o,
|
args.batch_stride_o,
|
||||||
args.window_size_left,
|
args.window_size_left,
|
||||||
args.window_size_right,
|
args.window_size_right,
|
||||||
args.sink_size,
|
|
||||||
args.mask_type,
|
args.mask_type,
|
||||||
args.p_drop,
|
args.p_drop,
|
||||||
args.s_randval,
|
args.s_randval,
|
||||||
@@ -732,7 +727,6 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
|
|||||||
args.batch_stride_v,
|
args.batch_stride_v,
|
||||||
args.window_size_left,
|
args.window_size_left,
|
||||||
args.window_size_right,
|
args.window_size_right,
|
||||||
args.sink_size,
|
|
||||||
args.mask_type,
|
args.mask_type,
|
||||||
args.min_seqlen_q);
|
args.min_seqlen_q);
|
||||||
}
|
}
|
||||||
@@ -778,7 +772,6 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
|
|||||||
args.batch_stride_o,
|
args.batch_stride_o,
|
||||||
args.window_size_left,
|
args.window_size_left,
|
||||||
args.window_size_right,
|
args.window_size_right,
|
||||||
args.sink_size,
|
|
||||||
args.mask_type);
|
args.mask_type);
|
||||||
}
|
}
|
||||||
}();
|
}();
|
||||||
@@ -845,7 +838,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
|||||||
args.split_stride_o_acc,
|
args.split_stride_o_acc,
|
||||||
args.window_size_left,
|
args.window_size_left,
|
||||||
args.window_size_right,
|
args.window_size_right,
|
||||||
args.sink_size,
|
|
||||||
args.mask_type);
|
args.mask_type);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
@@ -893,7 +885,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
|||||||
args.split_stride_o_acc,
|
args.split_stride_o_acc,
|
||||||
args.window_size_left,
|
args.window_size_left,
|
||||||
args.window_size_right,
|
args.window_size_right,
|
||||||
args.sink_size,
|
|
||||||
args.mask_type);
|
args.mask_type);
|
||||||
}
|
}
|
||||||
}();
|
}();
|
||||||
@@ -1140,8 +1131,7 @@ template <ck_tile::index_t HDim_,
|
|||||||
bool kPadD_,
|
bool kPadD_,
|
||||||
bool kPadDv_,
|
bool kPadDv_,
|
||||||
bool kUseTrLoad_,
|
bool kUseTrLoad_,
|
||||||
bool kSkipMinSeqlenQ_ = false,
|
bool kSkipMinSeqlenQ_ = false>
|
||||||
bool kHasSink_ = false>
|
|
||||||
struct fmha_fwd_traits_
|
struct fmha_fwd_traits_
|
||||||
{
|
{
|
||||||
static constexpr ck_tile::index_t HDim = HDim_;
|
static constexpr ck_tile::index_t HDim = HDim_;
|
||||||
@@ -1167,7 +1157,6 @@ struct fmha_fwd_traits_
|
|||||||
static constexpr bool kPadDv = kPadDv_;
|
static constexpr bool kPadDv = kPadDv_;
|
||||||
static constexpr bool kUseTrLoad = kUseTrLoad_;
|
static constexpr bool kUseTrLoad = kUseTrLoad_;
|
||||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||||
static constexpr bool kHasSink = kHasSink_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename Traits_, typename Arch = void>
|
template <typename Traits_, typename Arch = void>
|
||||||
@@ -1194,8 +1183,7 @@ template <ck_tile::index_t HDim_,
|
|||||||
bool kPadSK_,
|
bool kPadSK_,
|
||||||
bool kPadD_,
|
bool kPadD_,
|
||||||
bool kPadDv_,
|
bool kPadDv_,
|
||||||
bool kSkipMinSeqlenQ_ = false,
|
bool kSkipMinSeqlenQ_ = false>
|
||||||
bool kHasSink_ = false>
|
|
||||||
struct fmha_fwd_pagedkv_traits_
|
struct fmha_fwd_pagedkv_traits_
|
||||||
{
|
{
|
||||||
static constexpr ck_tile::index_t HDim = HDim_;
|
static constexpr ck_tile::index_t HDim = HDim_;
|
||||||
@@ -1220,7 +1208,6 @@ struct fmha_fwd_pagedkv_traits_
|
|||||||
static constexpr bool kPadD = kPadD_;
|
static constexpr bool kPadD = kPadD_;
|
||||||
static constexpr bool kPadDv = kPadDv_;
|
static constexpr bool kPadDv = kPadDv_;
|
||||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||||
static constexpr bool kHasSink = kHasSink_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename Traits_, typename Arch = void>
|
template <typename Traits_, typename Arch = void>
|
||||||
@@ -1243,7 +1230,6 @@ template <ck_tile::index_t HDim_,
|
|||||||
bool kStoreLse_,
|
bool kStoreLse_,
|
||||||
bool kDoFp8StaticQuant_,
|
bool kDoFp8StaticQuant_,
|
||||||
bool kIsPagedKV_,
|
bool kIsPagedKV_,
|
||||||
bool kHasSink_,
|
|
||||||
bool kPadS_,
|
bool kPadS_,
|
||||||
bool kPadSK_,
|
bool kPadSK_,
|
||||||
bool kPadD_,
|
bool kPadD_,
|
||||||
@@ -1271,7 +1257,6 @@ struct fmha_fwd_splitkv_traits_
|
|||||||
static constexpr bool kPadD = kPadD_;
|
static constexpr bool kPadD = kPadD_;
|
||||||
static constexpr bool kPadDv = kPadDv_;
|
static constexpr bool kPadDv = kPadDv_;
|
||||||
static constexpr bool kIsPagedKV = kIsPagedKV_;
|
static constexpr bool kIsPagedKV = kIsPagedKV_;
|
||||||
static constexpr bool kHasSink = kHasSink_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename Traits_, typename Arch = void>
|
template <typename Traits_, typename Arch = void>
|
||||||
@@ -1358,7 +1343,6 @@ struct fmha_fwd_traits
|
|||||||
bool has_dropout;
|
bool has_dropout;
|
||||||
bool do_fp8_static_quant;
|
bool do_fp8_static_quant;
|
||||||
bool skip_min_seqlen_q = false;
|
bool skip_min_seqlen_q = false;
|
||||||
bool has_sink = false;
|
|
||||||
// TODO: padding check is inside this api
|
// TODO: padding check is inside this api
|
||||||
};
|
};
|
||||||
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
|
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
|
||||||
@@ -1377,7 +1361,6 @@ struct fmha_fwd_pagedkv_traits
|
|||||||
bool use_pagedkv = true;
|
bool use_pagedkv = true;
|
||||||
bool do_fp8_static_quant = false;
|
bool do_fp8_static_quant = false;
|
||||||
bool skip_min_seqlen_q = false;
|
bool skip_min_seqlen_q = false;
|
||||||
bool has_sink = false;
|
|
||||||
// TODO: padding check is inside this api
|
// TODO: padding check is inside this api
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -1397,7 +1380,6 @@ struct fmha_fwd_splitkv_traits
|
|||||||
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
|
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
|
||||||
bool has_lse;
|
bool has_lse;
|
||||||
bool do_fp8_static_quant;
|
bool do_fp8_static_quant;
|
||||||
bool has_sink = false;
|
|
||||||
// TODO: padding check is inside this api
|
// TODO: padding check is inside this api
|
||||||
};
|
};
|
||||||
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits,
|
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits,
|
||||||
|
|||||||
@@ -907,7 +907,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
|||||||
traits.has_logits_soft_cap = 0.f < logits_soft_cap;
|
traits.has_logits_soft_cap = 0.f < logits_soft_cap;
|
||||||
traits.mask_type = mask.type;
|
traits.mask_type = mask.type;
|
||||||
traits.bias_type = bias.type;
|
traits.bias_type = bias.type;
|
||||||
traits.has_sink = mask.sink > 0 ? true : false;
|
|
||||||
traits.has_lse = lse;
|
traits.has_lse = lse;
|
||||||
traits.do_fp8_static_quant = squant;
|
traits.do_fp8_static_quant = squant;
|
||||||
|
|
||||||
@@ -1073,7 +1072,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
|||||||
|
|
||||||
args.window_size_left = mask.left;
|
args.window_size_left = mask.left;
|
||||||
args.window_size_right = mask.right;
|
args.window_size_right = mask.right;
|
||||||
args.sink_size = mask.sink;
|
|
||||||
args.mask_type = static_cast<ck_tile::index_t>(mask.type);
|
args.mask_type = static_cast<ck_tile::index_t>(mask.type);
|
||||||
|
|
||||||
if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
|
if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
|
||||||
@@ -1662,7 +1660,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
|||||||
ck_tile::reference_batched_masking<SaccDataType>(
|
ck_tile::reference_batched_masking<SaccDataType>(
|
||||||
s_host_ref,
|
s_host_ref,
|
||||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
|
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
|
||||||
mask.left, mask.right, mask.sink, real_seqlen_q, real_seqlen_k));
|
mask.left, mask.right, real_seqlen_q, real_seqlen_k));
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
@@ -1674,7 +1672,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
|||||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
|
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
|
||||||
mask.left,
|
mask.left,
|
||||||
mask.right,
|
mask.right,
|
||||||
mask.sink,
|
|
||||||
real_seqlen_q,
|
real_seqlen_q,
|
||||||
real_seqlen_k,
|
real_seqlen_k,
|
||||||
mask.type == mask_enum::mask_top_left));
|
mask.type == mask_enum::mask_top_left));
|
||||||
@@ -1684,7 +1681,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
|||||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
|
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
|
||||||
mask.left,
|
mask.left,
|
||||||
mask.right,
|
mask.right,
|
||||||
mask.sink,
|
|
||||||
real_seqlen_q,
|
real_seqlen_q,
|
||||||
real_seqlen_k,
|
real_seqlen_k,
|
||||||
mask.type == mask_enum::mask_top_left));
|
mask.type == mask_enum::mask_top_left));
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ struct mask_info
|
|||||||
ck_tile::index_t seqlen_k;
|
ck_tile::index_t seqlen_k;
|
||||||
ck_tile::index_t y, x;
|
ck_tile::index_t y, x;
|
||||||
ck_tile::index_t left, right; // FA style SWA left/right
|
ck_tile::index_t left, right; // FA style SWA left/right
|
||||||
ck_tile::index_t sink;
|
|
||||||
|
|
||||||
void serialize(std::ostream& os) const
|
void serialize(std::ostream& os) const
|
||||||
{
|
{
|
||||||
@@ -59,14 +58,13 @@ struct mask_info
|
|||||||
ck_tile::index_t window_size = std::stoi(v);
|
ck_tile::index_t window_size = std::stoi(v);
|
||||||
ck_tile::index_t left_size = -1;
|
ck_tile::index_t left_size = -1;
|
||||||
ck_tile::index_t right_size = 0;
|
ck_tile::index_t right_size = 0;
|
||||||
ck_tile::index_t sink_size = 0;
|
|
||||||
if(window_size > 0)
|
if(window_size > 0)
|
||||||
{
|
{
|
||||||
left_size = window_size / 2;
|
left_size = window_size / 2;
|
||||||
right_size = window_size - 1 - left_size;
|
right_size = window_size - 1 - left_size;
|
||||||
}
|
}
|
||||||
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
||||||
left_size, right_size, sink_size, y_total, x_total, t == "xt");
|
left_size, right_size, y_total, x_total, t == "xt");
|
||||||
|
|
||||||
tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right;
|
tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right;
|
||||||
tmp.y = r.at(ck_tile::number<0>{});
|
tmp.y = r.at(ck_tile::number<0>{});
|
||||||
@@ -81,54 +79,27 @@ struct mask_info
|
|||||||
{
|
{
|
||||||
throw std::invalid_argument("invalid mask value: " + str);
|
throw std::invalid_argument("invalid mask value: " + str);
|
||||||
}
|
}
|
||||||
tmp.type = mask_enum::window_generic;
|
ck_tile::index_t v0 = std::stoi(v.substr(0, found_1));
|
||||||
ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str());
|
ck_tile::index_t v1 = std::stoi(v.substr(found_1 + 1));
|
||||||
auto found_2 = v.find(',', found_1 + 1);
|
|
||||||
ck_tile::index_t v1 = 0;
|
|
||||||
ck_tile::index_t sink = 0;
|
|
||||||
// ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str());
|
|
||||||
// TODO: some validation
|
|
||||||
if(t == "t")
|
if(t == "t")
|
||||||
{
|
{
|
||||||
if(found_2 != std::string::npos)
|
|
||||||
{
|
|
||||||
v1 = atoi(v.substr(found_1 + 1, found_2 - found_1 - 1).c_str());
|
|
||||||
sink = atoi(v.substr(found_2 + 1).c_str());
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
v1 = atoi(v.substr(found_1 + 1).c_str());
|
|
||||||
sink = 0;
|
|
||||||
}
|
|
||||||
tmp.type = mask_enum::mask_top_left;
|
tmp.type = mask_enum::mask_top_left;
|
||||||
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
||||||
v0, v1, sink, y_total, x_total, true);
|
v0, v1, y_total, x_total, true);
|
||||||
tmp.y = r.at(ck_tile::number<0>{});
|
tmp.y = r.at(ck_tile::number<0>{});
|
||||||
tmp.x = r.at(ck_tile::number<1>{});
|
tmp.x = r.at(ck_tile::number<1>{});
|
||||||
tmp.left = v0;
|
tmp.left = v0;
|
||||||
tmp.right = v1;
|
tmp.right = v1;
|
||||||
tmp.sink = sink;
|
|
||||||
}
|
}
|
||||||
else if(t == "b")
|
else if(t == "b")
|
||||||
{
|
{
|
||||||
if(found_2 != std::string::npos)
|
|
||||||
{
|
|
||||||
v1 = atoi(v.substr(found_1 + 1, found_2 - found_1 - 1).c_str());
|
|
||||||
sink = atoi(v.substr(found_2 + 1).c_str());
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
v1 = atoi(v.substr(found_1 + 1).c_str());
|
|
||||||
sink = 0;
|
|
||||||
}
|
|
||||||
tmp.type = mask_enum::mask_bottom_right;
|
tmp.type = mask_enum::mask_bottom_right;
|
||||||
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
||||||
v0, v1, sink, y_total, x_total, false);
|
v0, v1, y_total, x_total, false);
|
||||||
tmp.y = r.at(ck_tile::number<0>{});
|
tmp.y = r.at(ck_tile::number<0>{});
|
||||||
tmp.x = r.at(ck_tile::number<1>{});
|
tmp.x = r.at(ck_tile::number<1>{});
|
||||||
tmp.left = v0;
|
tmp.left = v0;
|
||||||
tmp.right = v1;
|
tmp.right = v1;
|
||||||
tmp.sink = sink;
|
|
||||||
}
|
}
|
||||||
else if(t == "g")
|
else if(t == "g")
|
||||||
{
|
{
|
||||||
@@ -137,7 +108,6 @@ struct mask_info
|
|||||||
tmp.x = v1;
|
tmp.x = v1;
|
||||||
tmp.left = v0; // TODO: don't use this?
|
tmp.left = v0; // TODO: don't use this?
|
||||||
tmp.right = v1;
|
tmp.right = v1;
|
||||||
tmp.sink = 0;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
@@ -156,7 +126,6 @@ struct mask_info
|
|||||||
tmp.x = 1;
|
tmp.x = 1;
|
||||||
tmp.left = -1;
|
tmp.left = -1;
|
||||||
tmp.right = 0;
|
tmp.right = 0;
|
||||||
tmp.sink = 0;
|
|
||||||
}
|
}
|
||||||
else if(str == "2" || str == "b")
|
else if(str == "2" || str == "b")
|
||||||
{
|
{
|
||||||
@@ -165,7 +134,6 @@ struct mask_info
|
|||||||
tmp.x = seqlen_k - seqlen_q + 1;
|
tmp.x = seqlen_k - seqlen_q + 1;
|
||||||
tmp.left = -1;
|
tmp.left = -1;
|
||||||
tmp.right = 0;
|
tmp.right = 0;
|
||||||
tmp.sink = 0;
|
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -1,74 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
# TODO: run this script from CK root or build directory
|
|
||||||
EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
|
|
||||||
KNAME=1
|
|
||||||
|
|
||||||
export CK_WARMUP=0
|
|
||||||
export CK_REPEAT=1
|
|
||||||
|
|
||||||
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
|
|
||||||
# mode=0
|
|
||||||
# export HIP_VISIBLE_DEVICES=4
|
|
||||||
|
|
||||||
TEST_SPLITKV=0
|
|
||||||
TEST_APPENDKV=0
|
|
||||||
# options:
|
|
||||||
# -s: run splitkv tests
|
|
||||||
# -a: run appendkv tests
|
|
||||||
while getopts ":sa" opt; do
|
|
||||||
case "${opt}" in
|
|
||||||
s)
|
|
||||||
TEST_SPLITKV=1
|
|
||||||
;;
|
|
||||||
a)
|
|
||||||
TEST_APPENDKV=1
|
|
||||||
;;
|
|
||||||
*)
|
|
||||||
;;
|
|
||||||
esac
|
|
||||||
done
|
|
||||||
|
|
||||||
run_fp16_bf16_tests() {
|
|
||||||
local NUM_SPLITS="1"
|
|
||||||
local PAGE_BLOCK_SIZE="0"
|
|
||||||
local CACHE_BATCH_IDX="0"
|
|
||||||
|
|
||||||
if [ $TEST_SPLITKV -eq 1 ] ; then
|
|
||||||
NUM_SPLITS="$NUM_SPLITS 2 3"
|
|
||||||
PAGE_BLOCK_SIZE="$PAGE_BLOCK_SIZE 128"
|
|
||||||
CACHE_BATCH_IDX="$CACHE_BATCH_IDX 1"
|
|
||||||
fi
|
|
||||||
|
|
||||||
for prec in "fp16"; do
|
|
||||||
for mode in 1 0 ; do
|
|
||||||
for perm in 0 1 ; do
|
|
||||||
for vlayout in "r" "c" ; do
|
|
||||||
for batch in 1 4; do
|
|
||||||
for head in 1; do
|
|
||||||
for h_k in 1; do
|
|
||||||
for q_seq in 128 512 ; do
|
|
||||||
for kv_seq in 128 1024; do
|
|
||||||
for hdim in 32 64 128 256; do #256
|
|
||||||
for lse in 0 1 ; do
|
|
||||||
for bias in "e" ; do
|
|
||||||
for p_drop in 0.0 0.2; do # 0.0
|
|
||||||
for mask in "t:2,0,4" "b:1,0,2"; do
|
|
||||||
for num_splits in $NUM_SPLITS ; do
|
|
||||||
for page_block_size in $PAGE_BLOCK_SIZE ; do
|
|
||||||
for cache_batch_idx in $CACHE_BATCH_IDX ; do
|
|
||||||
|
|
||||||
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
|
|
||||||
$EXE -prec=$prec -mode=$mode -b=$batch -h=$head -h_k=$h_k -d=16, -d_v=$hdim -s=$q_seq -s_k=$kv_seq -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS -mask=$mask
|
|
||||||
|
|
||||||
done ; done ; done ; done ; done
|
|
||||||
done ; done ; done ; done ; done
|
|
||||||
done ; done ; done ; done ; done
|
|
||||||
done ; done
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
set -x
|
|
||||||
|
|
||||||
run_fp16_bf16_tests
|
|
||||||
|
|
||||||
set +x
|
|
||||||
@@ -36,7 +36,6 @@ function print_log_header(){
|
|||||||
#run verification tests
|
#run verification tests
|
||||||
time example/ck_tile/01_fmha/script/smoke_test_fwd.sh
|
time example/ck_tile/01_fmha/script/smoke_test_fwd.sh
|
||||||
time example/ck_tile/01_fmha/script/smoke_test_bwd.sh
|
time example/ck_tile/01_fmha/script/smoke_test_bwd.sh
|
||||||
time example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh
|
|
||||||
|
|
||||||
#run performance benchmarks
|
#run performance benchmarks
|
||||||
export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log"
|
export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log"
|
||||||
|
|||||||
@@ -1,83 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
# TODO: run this script from CK root or build directory
|
|
||||||
#EXE="/code/composable_kernel/build/bin/tile_example_fmha_fwd"
|
|
||||||
set -euo pipefail
|
|
||||||
|
|
||||||
SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd)
|
|
||||||
EXE_NAME=tile_example_fmha_fwd
|
|
||||||
EXE="$(find . -name $EXE_NAME -type f | head -n 1)"
|
|
||||||
KNAME=1
|
|
||||||
GPU_arch=$GPU_arch
|
|
||||||
if [ -z "$GPU_arch" ] ; then
|
|
||||||
GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}')
|
|
||||||
fi
|
|
||||||
set -x
|
|
||||||
|
|
||||||
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
|
|
||||||
|
|
||||||
|
|
||||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:2,0,2
|
|
||||||
|
|
||||||
# window_size[2,0], sink_size = 2
|
|
||||||
|
|
||||||
# x=1/y=3
|
|
||||||
# 1 * * * * * * * 1 * * * * * * *
|
|
||||||
# 1 1 * * * * * * 1 1 * * * * * *
|
|
||||||
# 1 1 1 * * * * * ----> 1 1 1 * * * * *
|
|
||||||
# * 1 1 1 * * * * 1 1 1 1 * * * *
|
|
||||||
# * * 1 1 1 * * * 1 1 1 1 1 * * *
|
|
||||||
# * * * 1 1 1 * * 1 1 * 1 1 1 * *
|
|
||||||
# * * * * 1 1 1 * 1 1 * * 1 1 1 *
|
|
||||||
# * * * * * 1 1 1 1 1 * * * 1 1 1
|
|
||||||
# l=2/r=0(tl) l=2/r=0/s=2(tl)
|
|
||||||
|
|
||||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:0,3,2 #-mask=b:3,0,2
|
|
||||||
|
|
||||||
# x=4/y=1
|
|
||||||
# 1 1 1 1 * * * * 1 1 1 1 * * * *
|
|
||||||
# * 1 1 1 1 * * * 1 1 1 1 1 * * *
|
|
||||||
# * * 1 1 1 1 * * ----> 1 1 1 1 1 1 * *
|
|
||||||
# * * * 1 1 1 1 * 1 1 * 1 1 1 1 *
|
|
||||||
# * * * * 1 1 1 1 1 1 * * 1 1 1 1
|
|
||||||
# l=0/r=3(tl) l=0/r=3/s=2(tl)
|
|
||||||
# l=3/r=0(br) l=3/r=0/s=2(br)
|
|
||||||
|
|
||||||
|
|
||||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:1,0,2
|
|
||||||
|
|
||||||
# x=4/y=-1
|
|
||||||
# * * 1 1 * * * * 1 1 1 1 * * * *
|
|
||||||
# * * * 1 1 * * * 1 1 * 1 1 * * *
|
|
||||||
# * * * * 1 1 * * ----> 1 1 * * 1 1 * *
|
|
||||||
# * * * * * 1 1 * 1 1 * * * 1 1 *
|
|
||||||
# * * * * * * 1 1 1 1 * * * * 1 1
|
|
||||||
# l=1/r=0(br) l=1/r=0/s=2(br)
|
|
||||||
|
|
||||||
|
|
||||||
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:2,0,2
|
|
||||||
|
|
||||||
# x=-1/y=5
|
|
||||||
|
|
||||||
# * * * * * * * * * * * *
|
|
||||||
# * * * * * * * * * * * *
|
|
||||||
# 1 * * * * * 1 * * * * *
|
|
||||||
# 1 1 * * * * 1 1 * * * *
|
|
||||||
# 1 1 1 * * * ----> 1 1 1 * * *
|
|
||||||
# * 1 1 1 * * 1 1 1 1 * *
|
|
||||||
# * * 1 1 1 * 1 1 1 1 1 *
|
|
||||||
# * * * 1 1 1 1 1 * 1 1 1
|
|
||||||
# l=2/r=0(br) l=2/r=0/s=2(br)
|
|
||||||
|
|
||||||
|
|
||||||
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:-1,1,2
|
|
||||||
# x=-1/y=8
|
|
||||||
# * * * * * * * * * *
|
|
||||||
# * * * * * * * * * *
|
|
||||||
# 1 * * * * ----> 1 * * * *
|
|
||||||
# 1 1 * * * 1 1 * * *
|
|
||||||
# 1 1 1 * * 1 1 1 * *
|
|
||||||
# 1 1 1 1 * 1 1 1 1 *
|
|
||||||
# 1 1 1 1 1 1 1 1 1 1
|
|
||||||
# 1 1 1 1 1 1 1 1 1 1
|
|
||||||
# l=2/r=0(br) l=2/r=0/s=2(br)
|
|
||||||
|
|
||||||
@@ -20,7 +20,7 @@ CK_TILE_HOST void reference_batched_masking(HostTensor<CDataType>& c_b_m_n, cons
|
|||||||
{
|
{
|
||||||
for(int m = 0; m < M; ++m)
|
for(int m = 0; m < M; ++m)
|
||||||
{
|
{
|
||||||
if(mask.IsOutOfSinkBound(m, n))
|
if(mask.IsOutOfBound(m, n))
|
||||||
c_b_m_n(batch, m, n) = -ck_tile::numeric<CDataType>::infinity();
|
c_b_m_n(batch, m, n) = -ck_tile::numeric<CDataType>::infinity();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -86,22 +86,21 @@ struct GenericAttentionMask
|
|||||||
static constexpr const char* name = impl::MaskName<IsMasking, IsLocal>::name;
|
static constexpr const char* name = impl::MaskName<IsMasking, IsLocal>::name;
|
||||||
|
|
||||||
CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_total_, index_t x_total_)
|
CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_total_, index_t x_total_)
|
||||||
: GenericAttentionMask(0, 0, 0, y_total_, x_total_)
|
: GenericAttentionMask(0, 0, y_total_, x_total_)
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
CK_TILE_HOST_DEVICE
|
CK_TILE_HOST_DEVICE
|
||||||
GenericAttentionMask(index_t y_, index_t x_, index_t sink_, index_t y_total_, index_t x_total_)
|
GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
|
||||||
: y(y_), x(x_), sink(sink_), y_total(y_total_), x_total(x_total_)
|
: y(y_), x(x_), y_total(y_total_), x_total(x_total_)
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
template <typename MaskCoordinates>
|
template <typename MaskCoordinates>
|
||||||
CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord)
|
CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord)
|
||||||
: y(mask_coord.at(number<0>{})),
|
: y(mask_coord.at(number<0>{})),
|
||||||
x(mask_coord.at(number<1>{})),
|
x(mask_coord.at(number<1>{})),
|
||||||
sink(mask_coord.at(number<2>{})),
|
y_total(mask_coord.at(number<2>{})),
|
||||||
y_total(mask_coord.at(number<3>{})),
|
x_total(mask_coord.at(number<3>{}))
|
||||||
x_total(mask_coord.at(number<4>{}))
|
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -142,44 +141,6 @@ struct GenericAttentionMask
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <index_t YTile, index_t XTile>
|
|
||||||
CK_TILE_HOST_DEVICE constexpr auto
|
|
||||||
GetSinkTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
|
|
||||||
{
|
|
||||||
if constexpr(!IsMasking)
|
|
||||||
{
|
|
||||||
return ck_tile::make_tuple(0, 0, x_total);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
// get the tile start/end range assum we loop over along X tile by tile
|
|
||||||
index_t x_start = [&]() {
|
|
||||||
if constexpr(IsLocal)
|
|
||||||
{
|
|
||||||
index_t tmp = max(-y + i_y + 1, 0);
|
|
||||||
return (tmp / XTile) * XTile; // round to tile aligned
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
}();
|
|
||||||
|
|
||||||
// TODO: end could be negative, we ignore clamp here, and let caller to check
|
|
||||||
// ... in which case end-start is negative
|
|
||||||
index_t x_end = [&]() {
|
|
||||||
index_t tmp = min(i_y + YTile - 1 + x, x_total);
|
|
||||||
return ((tmp + XTile - 1) / XTile) * XTile;
|
|
||||||
}();
|
|
||||||
|
|
||||||
index_t sink_seq_end = sink > 0 ? ((sink + XTile - 1) / XTile) * XTile : 0;
|
|
||||||
if(x_start <= sink_seq_end && sink > 0)
|
|
||||||
return ck_tile::make_tuple(0, 0, x_end);
|
|
||||||
else
|
|
||||||
return ck_tile::make_tuple(sink_seq_end, x_start, x_end);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// to get the loop length along Y axis, return index:[start, end), end-start=length
|
// to get the loop length along Y axis, return index:[start, end), end-start=length
|
||||||
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
|
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
|
||||||
// TODO: y_end still could be negative, so end-start could be negative(need check)
|
// TODO: y_end still could be negative, so end-start could be negative(need check)
|
||||||
@@ -234,30 +195,6 @@ struct GenericAttentionMask
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
CK_TILE_HOST_DEVICE constexpr auto IsOutOfSinkBound(index_t i_y, index_t i_x) const
|
|
||||||
{
|
|
||||||
if constexpr(!IsMasking)
|
|
||||||
return i_x >= x_total;
|
|
||||||
// no need to do min/max here, since i_x will never be < 0 or >= x_total
|
|
||||||
index_t x_start = -y + i_y + 1;
|
|
||||||
index_t x_end = min(i_y + x, x_total);
|
|
||||||
|
|
||||||
if constexpr(IsLocal)
|
|
||||||
{
|
|
||||||
if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total)
|
|
||||||
return false;
|
|
||||||
else
|
|
||||||
return i_x < x_start || i_x >= x_end;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total)
|
|
||||||
return false;
|
|
||||||
else
|
|
||||||
return i_x >= x_end || i_y >= y_total;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// if current tile is at the edge, means need per-pixel mask check.
|
// if current tile is at the edge, means need per-pixel mask check.
|
||||||
// otherwise no need to check per-pixel
|
// otherwise no need to check per-pixel
|
||||||
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
|
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
|
||||||
@@ -300,7 +237,7 @@ struct GenericAttentionMask
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
index_t y, x, sink;
|
index_t y, x;
|
||||||
index_t y_total, x_total;
|
index_t y_total, x_total;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -323,23 +260,21 @@ struct SimplifiedGenericAttentionMask
|
|||||||
static constexpr const char* name = impl::SimplifiedMaskName<IsMasking>::name;
|
static constexpr const char* name = impl::SimplifiedMaskName<IsMasking>::name;
|
||||||
|
|
||||||
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_total_, index_t x_total_)
|
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_total_, index_t x_total_)
|
||||||
: SimplifiedGenericAttentionMask(0, 0, 0, y_total_, x_total_)
|
: SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_)
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
CK_TILE_HOST_DEVICE
|
CK_TILE_HOST_DEVICE
|
||||||
SimplifiedGenericAttentionMask(
|
SimplifiedGenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
|
||||||
index_t y_, index_t x_, index_t sink_, index_t y_total_, index_t x_total_)
|
: y(y_), x(x_), y_total(y_total_), x_total(x_total_)
|
||||||
: y(y_), x(x_), sink(sink_), y_total(y_total_), x_total(x_total_)
|
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
template <typename MaskCoordinates>
|
template <typename MaskCoordinates>
|
||||||
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates& mask_coord)
|
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates& mask_coord)
|
||||||
: y(mask_coord.at(number<0>{})),
|
: y(mask_coord.at(number<0>{})),
|
||||||
x(mask_coord.at(number<1>{})),
|
x(mask_coord.at(number<1>{})),
|
||||||
sink(mask_coord.at(number<2>{})),
|
y_total(mask_coord.at(number<2>{})),
|
||||||
y_total(mask_coord.at(number<3>{})),
|
x_total(mask_coord.at(number<3>{}))
|
||||||
x_total(mask_coord.at(number<4>{}))
|
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -373,38 +308,6 @@ struct SimplifiedGenericAttentionMask
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <index_t YTile, index_t XTile>
|
|
||||||
CK_TILE_HOST_DEVICE constexpr auto
|
|
||||||
GetSinkTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
|
|
||||||
{
|
|
||||||
if constexpr(!IsMasking)
|
|
||||||
{
|
|
||||||
return ck_tile::make_tuple(0, 0, x_total);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
// get the tile start/end range assum we loop over along X tile by tile
|
|
||||||
index_t x_start = [&]() {
|
|
||||||
index_t tmp = max(-y + i_y + 1, 0);
|
|
||||||
return (tmp / XTile) * XTile; // round to tile aligned
|
|
||||||
}();
|
|
||||||
|
|
||||||
// TODO: end could be negative, we ignore clamp here, and let caller to check
|
|
||||||
// ... in which case end-start is negative
|
|
||||||
index_t x_end = [&]() {
|
|
||||||
index_t tmp = min(i_y + YTile - 1 + x, x_total);
|
|
||||||
return ((tmp + XTile - 1) / XTile) * XTile;
|
|
||||||
}();
|
|
||||||
|
|
||||||
index_t sink_seq_end = sink > 0 ? ((sink + XTile - 1) / XTile) * XTile : 0;
|
|
||||||
|
|
||||||
if(x_start <= sink_seq_end && sink > 0)
|
|
||||||
return ck_tile::make_tuple(0, 0, x_end);
|
|
||||||
else
|
|
||||||
return ck_tile::make_tuple(sink_seq_end, x_start, x_end);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <index_t TileHeight, index_t TileWidth>
|
template <index_t TileHeight, index_t TileWidth>
|
||||||
CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y,
|
CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y,
|
||||||
number<TileHeight> height,
|
number<TileHeight> height,
|
||||||
@@ -422,29 +325,6 @@ struct SimplifiedGenericAttentionMask
|
|||||||
ck_tile::min(origin_end, split_end));
|
ck_tile::min(origin_end, split_end));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <index_t TileHeight, index_t TileWidth>
|
|
||||||
CK_TILE_HOST_DEVICE constexpr auto GetSinkTileRangeAlongX(index_t i_y,
|
|
||||||
number<TileHeight> height,
|
|
||||||
number<TileWidth> width,
|
|
||||||
index_t num_splits,
|
|
||||||
index_t i_split) const
|
|
||||||
{
|
|
||||||
auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);
|
|
||||||
const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits));
|
|
||||||
const index_t split_start = x_per_split * i_split; // 128
|
|
||||||
const index_t split_end = ck_tile::min(x_total, split_start + x_per_split); // 256
|
|
||||||
const index_t sink_seq_end = sink > 0 ? ((sink + width - 1) / width) * width : 0;
|
|
||||||
const index_t start = ck_tile::max(origin_start, split_start);
|
|
||||||
const index_t end = ck_tile::min(origin_end, split_end);
|
|
||||||
const bool is_first_intersecting_split =
|
|
||||||
(split_start <= origin_start && split_end >= origin_start);
|
|
||||||
const bool sink_in_range = (sink_seq_end <= start);
|
|
||||||
|
|
||||||
const index_t sink_offset =
|
|
||||||
(is_first_intersecting_split && sink_in_range) ? sink_seq_end : 0;
|
|
||||||
return ck_tile::make_tuple(sink_offset, start, end);
|
|
||||||
}
|
|
||||||
|
|
||||||
// to get the loop length along Y axis, return index:[start, end), end-start=length
|
// to get the loop length along Y axis, return index:[start, end), end-start=length
|
||||||
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
|
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
|
||||||
// TODO: y_end still could be negative, so end-start could be negative(need check)
|
// TODO: y_end still could be negative, so end-start could be negative(need check)
|
||||||
@@ -488,22 +368,11 @@ struct SimplifiedGenericAttentionMask
|
|||||||
{
|
{
|
||||||
index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
|
index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
|
||||||
index_t x_end = min(i_y + x, x_total); // need min in case x is padded
|
index_t x_end = min(i_y + x, x_total); // need min in case x is padded
|
||||||
|
|
||||||
return i_x < x_start || i_x >= x_end || i_y >= y_total;
|
return i_x < x_start || i_x >= x_end || i_y >= y_total;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
CK_TILE_HOST_DEVICE constexpr auto IsOutOfSinkBound(index_t i_y, index_t i_x) const
|
|
||||||
{
|
|
||||||
if constexpr(!IsMasking)
|
|
||||||
return i_x >= x_total;
|
|
||||||
index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
|
|
||||||
index_t x_end = min(i_y + x, x_total); // need min in case x is padded
|
|
||||||
if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total)
|
|
||||||
return false;
|
|
||||||
else
|
|
||||||
return i_x < x_start || i_x >= x_end || i_y >= y_total;
|
|
||||||
}
|
|
||||||
|
|
||||||
// if current tile is at the edge, means need per-pixel mask check.
|
// if current tile is at the edge, means need per-pixel mask check.
|
||||||
// otherwise no need to check per-pixel
|
// otherwise no need to check per-pixel
|
||||||
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
|
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
|
||||||
@@ -537,7 +406,7 @@ struct SimplifiedGenericAttentionMask
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
index_t y, x, sink;
|
index_t y, x;
|
||||||
index_t y_total, x_total;
|
index_t y_total, x_total;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -738,7 +607,6 @@ struct SimplifiedRatioAttentionMask
|
|||||||
CK_TILE_HOST_DEVICE constexpr auto
|
CK_TILE_HOST_DEVICE constexpr auto
|
||||||
make_generic_attention_mask_coordinates_from_lr_window(index_t left_size,
|
make_generic_attention_mask_coordinates_from_lr_window(index_t left_size,
|
||||||
index_t right_size,
|
index_t right_size,
|
||||||
index_t sink_size,
|
|
||||||
index_t y_total,
|
index_t y_total,
|
||||||
index_t x_total,
|
index_t x_total,
|
||||||
bool is_top_left = true)
|
bool is_top_left = true)
|
||||||
@@ -756,21 +624,7 @@ make_generic_attention_mask_coordinates_from_lr_window(index_t left_size,
|
|||||||
index_t x = 1 + right_size + x_tmp;
|
index_t x = 1 + right_size + x_tmp;
|
||||||
index_t y = 1 + left_size + y_tmp;
|
index_t y = 1 + left_size + y_tmp;
|
||||||
|
|
||||||
return ck_tile::make_tuple(y, x, sink_size, y_total, x_total);
|
return ck_tile::make_tuple(y, x, y_total, x_total);
|
||||||
}
|
|
||||||
|
|
||||||
template <typename MaskType>
|
|
||||||
CK_TILE_HOST_DEVICE constexpr auto
|
|
||||||
make_generic_attention_mask_from_lr_window(index_t left_size,
|
|
||||||
index_t right_size,
|
|
||||||
index_t sink_size,
|
|
||||||
index_t y_total,
|
|
||||||
index_t x_total,
|
|
||||||
bool is_top_left = true)
|
|
||||||
{
|
|
||||||
auto r = make_generic_attention_mask_coordinates_from_lr_window(
|
|
||||||
left_size, right_size, sink_size, y_total, x_total, is_top_left);
|
|
||||||
return MaskType{r.at(number<0>{}), r.at(number<1>{}), sink_size, y_total, x_total};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename MaskType>
|
template <typename MaskType>
|
||||||
@@ -782,7 +636,7 @@ make_generic_attention_mask_from_lr_window(index_t left_size,
|
|||||||
bool is_top_left = true)
|
bool is_top_left = true)
|
||||||
{
|
{
|
||||||
auto r = make_generic_attention_mask_coordinates_from_lr_window(
|
auto r = make_generic_attention_mask_coordinates_from_lr_window(
|
||||||
left_size, right_size, 0, y_total, x_total, is_top_left);
|
left_size, right_size, y_total, x_total, is_top_left);
|
||||||
return MaskType{r.at(number<0>{}), r.at(number<1>{}), 0, y_total, x_total};
|
return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total};
|
||||||
}
|
}
|
||||||
} // namespace ck_tile
|
} // namespace ck_tile
|
||||||
|
|||||||
@@ -162,17 +162,6 @@ struct StandardAttention
|
|||||||
{
|
{
|
||||||
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
|
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Params>
|
|
||||||
__device__ __forceinline__ bool LogitsSinkMask(const Params& params,
|
|
||||||
[[maybe_unused]] uint32_t batch_idx,
|
|
||||||
uint32_t qo_idx,
|
|
||||||
uint32_t kv_idx,
|
|
||||||
[[maybe_unused]] uint32_t qo_head_idx,
|
|
||||||
[[maybe_unused]] uint32_t kv_head_idx) const
|
|
||||||
{
|
|
||||||
return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <bool UseExp2 = false>
|
template <bool UseExp2 = false>
|
||||||
@@ -235,17 +224,6 @@ struct LogitsSoftCap
|
|||||||
{
|
{
|
||||||
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
|
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Params>
|
|
||||||
__device__ __forceinline__ bool LogitsSinkMask(const Params& params,
|
|
||||||
[[maybe_unused]] uint32_t batch_idx,
|
|
||||||
uint32_t qo_idx,
|
|
||||||
uint32_t kv_idx,
|
|
||||||
[[maybe_unused]] uint32_t qo_head_idx,
|
|
||||||
[[maybe_unused]] uint32_t kv_head_idx) const
|
|
||||||
{
|
|
||||||
return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
constexpr uint32_t CUSTOM_MASK = 1U;
|
constexpr uint32_t CUSTOM_MASK = 1U;
|
||||||
@@ -319,17 +297,6 @@ struct ComposedAttention
|
|||||||
{
|
{
|
||||||
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
|
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Params>
|
|
||||||
__device__ __forceinline__ bool LogitsSinkMask(const Params& params,
|
|
||||||
[[maybe_unused]] uint32_t batch_idx,
|
|
||||||
uint32_t qo_idx,
|
|
||||||
uint32_t kv_idx,
|
|
||||||
[[maybe_unused]] uint32_t qo_head_idx,
|
|
||||||
[[maybe_unused]] uint32_t kv_head_idx) const
|
|
||||||
{
|
|
||||||
return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace ck_tile
|
} // namespace ck_tile
|
||||||
|
|||||||
@@ -198,7 +198,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
|||||||
struct FmhaFwdMaskKargs
|
struct FmhaFwdMaskKargs
|
||||||
{
|
{
|
||||||
// ck_tile::index_t window_size_left, window_size_right;
|
// ck_tile::index_t window_size_left, window_size_right;
|
||||||
ck_tile::index_t window_size_left, window_size_right, sink_size;
|
ck_tile::index_t window_size_left, window_size_right;
|
||||||
ck_tile::GenericAttentionMaskEnum mask_type;
|
ck_tile::GenericAttentionMaskEnum mask_type;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -362,7 +362,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
|||||||
ck_tile::index_t batch_stride_o,
|
ck_tile::index_t batch_stride_o,
|
||||||
ck_tile::index_t window_size_left,
|
ck_tile::index_t window_size_left,
|
||||||
ck_tile::index_t window_size_right,
|
ck_tile::index_t window_size_right,
|
||||||
ck_tile::index_t sink_size,
|
|
||||||
ck_tile::index_t mask_type,
|
ck_tile::index_t mask_type,
|
||||||
float p_drop,
|
float p_drop,
|
||||||
bool s_randval,
|
bool s_randval,
|
||||||
@@ -426,7 +425,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
|||||||
{
|
{
|
||||||
kargs.window_size_left = window_size_left;
|
kargs.window_size_left = window_size_left;
|
||||||
kargs.window_size_right = window_size_right;
|
kargs.window_size_right = window_size_right;
|
||||||
kargs.sink_size = sink_size;
|
|
||||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||||
}
|
}
|
||||||
if constexpr(kStoreLSE)
|
if constexpr(kStoreLSE)
|
||||||
@@ -511,7 +509,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
|||||||
ck_tile::index_t batch_stride_v,
|
ck_tile::index_t batch_stride_v,
|
||||||
ck_tile::index_t window_size_left,
|
ck_tile::index_t window_size_left,
|
||||||
ck_tile::index_t window_size_right,
|
ck_tile::index_t window_size_right,
|
||||||
ck_tile::index_t sink_size,
|
|
||||||
ck_tile::index_t mask_type,
|
ck_tile::index_t mask_type,
|
||||||
float p_drop,
|
float p_drop,
|
||||||
bool s_randval,
|
bool s_randval,
|
||||||
@@ -573,7 +570,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
|||||||
{
|
{
|
||||||
kargs.window_size_left = window_size_left;
|
kargs.window_size_left = window_size_left;
|
||||||
kargs.window_size_right = window_size_right;
|
kargs.window_size_right = window_size_right;
|
||||||
kargs.sink_size = sink_size;
|
|
||||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||||
}
|
}
|
||||||
if constexpr(kStoreLSE)
|
if constexpr(kStoreLSE)
|
||||||
@@ -1030,7 +1026,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
|||||||
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
||||||
kargs.window_size_left,
|
kargs.window_size_left,
|
||||||
kargs.window_size_right,
|
kargs.window_size_right,
|
||||||
kargs.sink_size,
|
|
||||||
kargs.seqlen_q,
|
kargs.seqlen_q,
|
||||||
kargs.seqlen_k,
|
kargs.seqlen_k,
|
||||||
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
|
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
|
||||||
|
|||||||
@@ -56,7 +56,6 @@ struct FmhaFwdKernel
|
|||||||
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
|
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
|
||||||
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
|
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
|
||||||
static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
|
static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
|
||||||
static constexpr bool kHasSink = FmhaPipeline::kHasSink;
|
|
||||||
|
|
||||||
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
|
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
|
||||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||||
@@ -113,7 +112,7 @@ struct FmhaFwdKernel
|
|||||||
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
|
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
|
||||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
|
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
|
||||||
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
||||||
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kUseTrLoad ? "_trload" : "_ntrload") + (kHasSink ? "_sink" : "_nsink");
|
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kUseTrLoad ? "_trload" : "_ntrload");
|
||||||
#undef _SS_
|
#undef _SS_
|
||||||
#undef _TS_
|
#undef _TS_
|
||||||
// clang-format on
|
// clang-format on
|
||||||
@@ -201,7 +200,7 @@ struct FmhaFwdKernel
|
|||||||
struct FmhaFwdMaskKargs
|
struct FmhaFwdMaskKargs
|
||||||
{
|
{
|
||||||
// ck_tile::index_t window_size_left, window_size_right;
|
// ck_tile::index_t window_size_left, window_size_right;
|
||||||
ck_tile::index_t window_size_left, window_size_right, sink_size;
|
ck_tile::index_t window_size_left, window_size_right;
|
||||||
ck_tile::GenericAttentionMaskEnum mask_type;
|
ck_tile::GenericAttentionMaskEnum mask_type;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -375,7 +374,6 @@ struct FmhaFwdKernel
|
|||||||
ck_tile::index_t batch_stride_o,
|
ck_tile::index_t batch_stride_o,
|
||||||
ck_tile::index_t window_size_left,
|
ck_tile::index_t window_size_left,
|
||||||
ck_tile::index_t window_size_right,
|
ck_tile::index_t window_size_right,
|
||||||
ck_tile::index_t sink_size,
|
|
||||||
ck_tile::index_t mask_type,
|
ck_tile::index_t mask_type,
|
||||||
float p_drop,
|
float p_drop,
|
||||||
bool s_randval,
|
bool s_randval,
|
||||||
@@ -434,7 +432,6 @@ struct FmhaFwdKernel
|
|||||||
{
|
{
|
||||||
kargs.window_size_left = window_size_left;
|
kargs.window_size_left = window_size_left;
|
||||||
kargs.window_size_right = window_size_right;
|
kargs.window_size_right = window_size_right;
|
||||||
kargs.sink_size = sink_size;
|
|
||||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||||
}
|
}
|
||||||
if constexpr(kStoreLSE)
|
if constexpr(kStoreLSE)
|
||||||
@@ -521,7 +518,6 @@ struct FmhaFwdKernel
|
|||||||
ck_tile::index_t batch_stride_o,
|
ck_tile::index_t batch_stride_o,
|
||||||
ck_tile::index_t window_size_left,
|
ck_tile::index_t window_size_left,
|
||||||
ck_tile::index_t window_size_right,
|
ck_tile::index_t window_size_right,
|
||||||
ck_tile::index_t sink_size,
|
|
||||||
ck_tile::index_t mask_type,
|
ck_tile::index_t mask_type,
|
||||||
float p_drop,
|
float p_drop,
|
||||||
bool s_randval,
|
bool s_randval,
|
||||||
@@ -569,7 +565,6 @@ struct FmhaFwdKernel
|
|||||||
batch_stride_o,
|
batch_stride_o,
|
||||||
window_size_left,
|
window_size_left,
|
||||||
window_size_right,
|
window_size_right,
|
||||||
sink_size,
|
|
||||||
mask_type,
|
mask_type,
|
||||||
p_drop,
|
p_drop,
|
||||||
s_randval,
|
s_randval,
|
||||||
@@ -620,7 +615,6 @@ struct FmhaFwdKernel
|
|||||||
ck_tile::index_t batch_stride_o,
|
ck_tile::index_t batch_stride_o,
|
||||||
ck_tile::index_t window_size_left,
|
ck_tile::index_t window_size_left,
|
||||||
ck_tile::index_t window_size_right,
|
ck_tile::index_t window_size_right,
|
||||||
ck_tile::index_t sink_size,
|
|
||||||
ck_tile::index_t mask_type,
|
ck_tile::index_t mask_type,
|
||||||
float p_drop,
|
float p_drop,
|
||||||
bool s_randval,
|
bool s_randval,
|
||||||
@@ -668,7 +662,6 @@ struct FmhaFwdKernel
|
|||||||
batch_stride_o,
|
batch_stride_o,
|
||||||
window_size_left,
|
window_size_left,
|
||||||
window_size_right,
|
window_size_right,
|
||||||
sink_size,
|
|
||||||
mask_type,
|
mask_type,
|
||||||
p_drop,
|
p_drop,
|
||||||
s_randval,
|
s_randval,
|
||||||
@@ -713,7 +706,6 @@ struct FmhaFwdKernel
|
|||||||
ck_tile::index_t nhead_stride_o,
|
ck_tile::index_t nhead_stride_o,
|
||||||
ck_tile::index_t window_size_left,
|
ck_tile::index_t window_size_left,
|
||||||
ck_tile::index_t window_size_right,
|
ck_tile::index_t window_size_right,
|
||||||
ck_tile::index_t sink_size,
|
|
||||||
ck_tile::index_t mask_type,
|
ck_tile::index_t mask_type,
|
||||||
ck_tile::index_t min_seqlen_q,
|
ck_tile::index_t min_seqlen_q,
|
||||||
float p_drop,
|
float p_drop,
|
||||||
@@ -773,7 +765,6 @@ struct FmhaFwdKernel
|
|||||||
{
|
{
|
||||||
kargs.window_size_left = window_size_left;
|
kargs.window_size_left = window_size_left;
|
||||||
kargs.window_size_right = window_size_right;
|
kargs.window_size_right = window_size_right;
|
||||||
kargs.sink_size = sink_size;
|
|
||||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||||
}
|
}
|
||||||
if constexpr(kStoreLSE)
|
if constexpr(kStoreLSE)
|
||||||
@@ -857,7 +848,6 @@ struct FmhaFwdKernel
|
|||||||
ck_tile::index_t nhead_stride_o,
|
ck_tile::index_t nhead_stride_o,
|
||||||
ck_tile::index_t window_size_left,
|
ck_tile::index_t window_size_left,
|
||||||
ck_tile::index_t window_size_right,
|
ck_tile::index_t window_size_right,
|
||||||
ck_tile::index_t sink_size,
|
|
||||||
ck_tile::index_t mask_type,
|
ck_tile::index_t mask_type,
|
||||||
ck_tile::index_t min_seqlen_q,
|
ck_tile::index_t min_seqlen_q,
|
||||||
float p_drop,
|
float p_drop,
|
||||||
@@ -901,7 +891,6 @@ struct FmhaFwdKernel
|
|||||||
nhead_stride_o,
|
nhead_stride_o,
|
||||||
window_size_left,
|
window_size_left,
|
||||||
window_size_right,
|
window_size_right,
|
||||||
sink_size,
|
|
||||||
mask_type,
|
mask_type,
|
||||||
min_seqlen_q,
|
min_seqlen_q,
|
||||||
p_drop,
|
p_drop,
|
||||||
@@ -948,7 +937,6 @@ struct FmhaFwdKernel
|
|||||||
ck_tile::index_t nhead_stride_o,
|
ck_tile::index_t nhead_stride_o,
|
||||||
ck_tile::index_t window_size_left,
|
ck_tile::index_t window_size_left,
|
||||||
ck_tile::index_t window_size_right,
|
ck_tile::index_t window_size_right,
|
||||||
ck_tile::index_t sink_size,
|
|
||||||
ck_tile::index_t mask_type,
|
ck_tile::index_t mask_type,
|
||||||
ck_tile::index_t min_seqlen_q,
|
ck_tile::index_t min_seqlen_q,
|
||||||
float p_drop,
|
float p_drop,
|
||||||
@@ -992,7 +980,6 @@ struct FmhaFwdKernel
|
|||||||
nhead_stride_o,
|
nhead_stride_o,
|
||||||
window_size_left,
|
window_size_left,
|
||||||
window_size_right,
|
window_size_right,
|
||||||
sink_size,
|
|
||||||
mask_type,
|
mask_type,
|
||||||
min_seqlen_q,
|
min_seqlen_q,
|
||||||
p_drop,
|
p_drop,
|
||||||
@@ -1484,7 +1471,6 @@ struct FmhaFwdKernel
|
|||||||
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
||||||
kargs.window_size_left,
|
kargs.window_size_left,
|
||||||
kargs.window_size_right,
|
kargs.window_size_right,
|
||||||
kargs.sink_size,
|
|
||||||
kargs.seqlen_q,
|
kargs.seqlen_q,
|
||||||
kargs.seqlen_k,
|
kargs.seqlen_k,
|
||||||
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
|
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
|
||||||
@@ -2214,7 +2200,6 @@ struct FmhaFwdKernel
|
|||||||
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
||||||
kargs.window_size_left,
|
kargs.window_size_left,
|
||||||
kargs.window_size_right,
|
kargs.window_size_right,
|
||||||
kargs.sink_size,
|
|
||||||
kargs.seqlen_q,
|
kargs.seqlen_q,
|
||||||
kargs.seqlen_k,
|
kargs.seqlen_k,
|
||||||
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
|
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
|
||||||
|
|||||||
@@ -55,7 +55,6 @@ struct FmhaFwdPagedKVKernel
|
|||||||
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
|
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
|
||||||
static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
|
static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
|
||||||
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
|
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
|
||||||
static constexpr bool kHasSink = FmhaPipeline::kHasSink;
|
|
||||||
|
|
||||||
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
|
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
|
||||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||||
@@ -102,7 +101,7 @@ struct FmhaFwdPagedKVKernel
|
|||||||
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
|
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
|
||||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
|
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
|
||||||
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
||||||
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" );
|
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" );
|
||||||
#undef _SS_
|
#undef _SS_
|
||||||
#undef _TS_
|
#undef _TS_
|
||||||
// clang-format on
|
// clang-format on
|
||||||
@@ -190,7 +189,7 @@ struct FmhaFwdPagedKVKernel
|
|||||||
struct FmhaFwdMaskKargs
|
struct FmhaFwdMaskKargs
|
||||||
{
|
{
|
||||||
// ck_tile::index_t window_size_left, window_size_right;
|
// ck_tile::index_t window_size_left, window_size_right;
|
||||||
ck_tile::index_t window_size_left, window_size_right, sink_size;
|
ck_tile::index_t window_size_left, window_size_right;
|
||||||
ck_tile::GenericAttentionMaskEnum mask_type;
|
ck_tile::GenericAttentionMaskEnum mask_type;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -327,7 +326,6 @@ struct FmhaFwdPagedKVKernel
|
|||||||
ck_tile::index_t batch_stride_o,
|
ck_tile::index_t batch_stride_o,
|
||||||
ck_tile::index_t window_size_left,
|
ck_tile::index_t window_size_left,
|
||||||
ck_tile::index_t window_size_right,
|
ck_tile::index_t window_size_right,
|
||||||
ck_tile::index_t sink_size,
|
|
||||||
ck_tile::index_t mask_type)
|
ck_tile::index_t mask_type)
|
||||||
{
|
{
|
||||||
Kargs kargs{{q_ptr,
|
Kargs kargs{{q_ptr,
|
||||||
@@ -381,7 +379,6 @@ struct FmhaFwdPagedKVKernel
|
|||||||
{
|
{
|
||||||
kargs.window_size_left = window_size_left;
|
kargs.window_size_left = window_size_left;
|
||||||
kargs.window_size_right = window_size_right;
|
kargs.window_size_right = window_size_right;
|
||||||
kargs.sink_size = sink_size;
|
|
||||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||||
}
|
}
|
||||||
if constexpr(kStoreLSE)
|
if constexpr(kStoreLSE)
|
||||||
@@ -456,7 +453,6 @@ struct FmhaFwdPagedKVKernel
|
|||||||
ck_tile::index_t batch_stride_o,
|
ck_tile::index_t batch_stride_o,
|
||||||
ck_tile::index_t window_size_left,
|
ck_tile::index_t window_size_left,
|
||||||
ck_tile::index_t window_size_right,
|
ck_tile::index_t window_size_right,
|
||||||
ck_tile::index_t sink_size,
|
|
||||||
ck_tile::index_t mask_type)
|
ck_tile::index_t mask_type)
|
||||||
{
|
{
|
||||||
return MakeKargsImpl(q_ptr,
|
return MakeKargsImpl(q_ptr,
|
||||||
@@ -499,7 +495,6 @@ struct FmhaFwdPagedKVKernel
|
|||||||
batch_stride_o,
|
batch_stride_o,
|
||||||
window_size_left,
|
window_size_left,
|
||||||
window_size_right,
|
window_size_right,
|
||||||
sink_size,
|
|
||||||
mask_type);
|
mask_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -541,7 +536,6 @@ struct FmhaFwdPagedKVKernel
|
|||||||
ck_tile::index_t batch_stride_v, // only used for paged-kvcache
|
ck_tile::index_t batch_stride_v, // only used for paged-kvcache
|
||||||
ck_tile::index_t window_size_left,
|
ck_tile::index_t window_size_left,
|
||||||
ck_tile::index_t window_size_right,
|
ck_tile::index_t window_size_right,
|
||||||
ck_tile::index_t sink_size,
|
|
||||||
ck_tile::index_t mask_type,
|
ck_tile::index_t mask_type,
|
||||||
ck_tile::index_t min_seqlen_q)
|
ck_tile::index_t min_seqlen_q)
|
||||||
{
|
{
|
||||||
@@ -596,7 +590,6 @@ struct FmhaFwdPagedKVKernel
|
|||||||
{
|
{
|
||||||
kargs.window_size_left = window_size_left;
|
kargs.window_size_left = window_size_left;
|
||||||
kargs.window_size_right = window_size_right;
|
kargs.window_size_right = window_size_right;
|
||||||
kargs.sink_size = sink_size;
|
|
||||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||||
}
|
}
|
||||||
if constexpr(kStoreLSE)
|
if constexpr(kStoreLSE)
|
||||||
@@ -667,7 +660,6 @@ struct FmhaFwdPagedKVKernel
|
|||||||
ck_tile::index_t batch_stride_v, // only used for paged-kvcache
|
ck_tile::index_t batch_stride_v, // only used for paged-kvcache
|
||||||
ck_tile::index_t window_size_left,
|
ck_tile::index_t window_size_left,
|
||||||
ck_tile::index_t window_size_right,
|
ck_tile::index_t window_size_right,
|
||||||
ck_tile::index_t sink_size,
|
|
||||||
ck_tile::index_t mask_type,
|
ck_tile::index_t mask_type,
|
||||||
ck_tile::index_t min_seqlen_q)
|
ck_tile::index_t min_seqlen_q)
|
||||||
{
|
{
|
||||||
@@ -707,7 +699,6 @@ struct FmhaFwdPagedKVKernel
|
|||||||
batch_stride_v,
|
batch_stride_v,
|
||||||
window_size_left,
|
window_size_left,
|
||||||
window_size_right,
|
window_size_right,
|
||||||
sink_size,
|
|
||||||
mask_type,
|
mask_type,
|
||||||
min_seqlen_q);
|
min_seqlen_q);
|
||||||
}
|
}
|
||||||
@@ -1266,7 +1257,6 @@ struct FmhaFwdPagedKVKernel
|
|||||||
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
||||||
kargs.window_size_left,
|
kargs.window_size_left,
|
||||||
kargs.window_size_right,
|
kargs.window_size_right,
|
||||||
kargs.sink_size,
|
|
||||||
kargs.seqlen_q,
|
kargs.seqlen_q,
|
||||||
kargs.seqlen_k,
|
kargs.seqlen_k,
|
||||||
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
|
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
|
||||||
|
|||||||
@@ -51,7 +51,6 @@ struct FmhaFwdSplitKVKernel
|
|||||||
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
|
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
|
||||||
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
|
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
|
||||||
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
|
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
|
||||||
static constexpr bool kHasSink = FmhaPipeline::Problem::kHasSink;
|
|
||||||
static constexpr bool kMergeNumHeadGroupsSeqLenQ =
|
static constexpr bool kMergeNumHeadGroupsSeqLenQ =
|
||||||
FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ;
|
FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ;
|
||||||
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
|
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
|
||||||
@@ -102,7 +101,7 @@ struct FmhaFwdSplitKVKernel
|
|||||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
|
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
|
||||||
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
||||||
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) +
|
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) +
|
||||||
(kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" );
|
(kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" );
|
||||||
#undef _SS_
|
#undef _SS_
|
||||||
#undef _TS_
|
#undef _TS_
|
||||||
// clang-format on
|
// clang-format on
|
||||||
@@ -199,7 +198,7 @@ struct FmhaFwdSplitKVKernel
|
|||||||
struct MaskKargs
|
struct MaskKargs
|
||||||
{
|
{
|
||||||
// ck_tile::index_t window_size_left, window_size_right;
|
// ck_tile::index_t window_size_left, window_size_right;
|
||||||
ck_tile::index_t window_size_left, window_size_right, sink_size;
|
ck_tile::index_t window_size_left, window_size_right;
|
||||||
ck_tile::GenericAttentionMaskEnum mask_type;
|
ck_tile::GenericAttentionMaskEnum mask_type;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -326,7 +325,6 @@ struct FmhaFwdSplitKVKernel
|
|||||||
ck_tile::index_t split_stride_o_acc,
|
ck_tile::index_t split_stride_o_acc,
|
||||||
ck_tile::index_t window_size_left,
|
ck_tile::index_t window_size_left,
|
||||||
ck_tile::index_t window_size_right,
|
ck_tile::index_t window_size_right,
|
||||||
ck_tile::index_t sink_size,
|
|
||||||
ck_tile::index_t mask_type)
|
ck_tile::index_t mask_type)
|
||||||
{
|
{
|
||||||
Kargs kargs{{q_ptr,
|
Kargs kargs{{q_ptr,
|
||||||
@@ -386,7 +384,6 @@ struct FmhaFwdSplitKVKernel
|
|||||||
{
|
{
|
||||||
kargs.window_size_left = window_size_left;
|
kargs.window_size_left = window_size_left;
|
||||||
kargs.window_size_right = window_size_right;
|
kargs.window_size_right = window_size_right;
|
||||||
kargs.sink_size = sink_size;
|
|
||||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||||
}
|
}
|
||||||
if constexpr(kDoFp8StaticQuant)
|
if constexpr(kDoFp8StaticQuant)
|
||||||
@@ -454,7 +451,6 @@ struct FmhaFwdSplitKVKernel
|
|||||||
ck_tile::index_t split_stride_o_acc,
|
ck_tile::index_t split_stride_o_acc,
|
||||||
ck_tile::index_t window_size_left,
|
ck_tile::index_t window_size_left,
|
||||||
ck_tile::index_t window_size_right,
|
ck_tile::index_t window_size_right,
|
||||||
ck_tile::index_t sink_size,
|
|
||||||
ck_tile::index_t mask_type)
|
ck_tile::index_t mask_type)
|
||||||
{
|
{
|
||||||
Kargs kargs{{q_ptr,
|
Kargs kargs{{q_ptr,
|
||||||
@@ -512,7 +508,6 @@ struct FmhaFwdSplitKVKernel
|
|||||||
{
|
{
|
||||||
kargs.window_size_left = window_size_left;
|
kargs.window_size_left = window_size_left;
|
||||||
kargs.window_size_right = window_size_right;
|
kargs.window_size_right = window_size_right;
|
||||||
kargs.sink_size = sink_size;
|
|
||||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||||
}
|
}
|
||||||
if constexpr(kDoFp8StaticQuant)
|
if constexpr(kDoFp8StaticQuant)
|
||||||
@@ -999,7 +994,6 @@ struct FmhaFwdSplitKVKernel
|
|||||||
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
||||||
kargs.window_size_left,
|
kargs.window_size_left,
|
||||||
kargs.window_size_right,
|
kargs.window_size_right,
|
||||||
kargs.sink_size,
|
|
||||||
kargs.seqlen_q,
|
kargs.seqlen_q,
|
||||||
kargs.seqlen_k,
|
kargs.seqlen_k,
|
||||||
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
|
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
|
||||||
|
|||||||
@@ -57,7 +57,6 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
|||||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||||
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
|
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
|
||||||
static constexpr bool kHasSink = Problem::kHasSink;
|
|
||||||
|
|
||||||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
||||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||||
@@ -229,22 +228,10 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
|||||||
clear_tile(o_acc);
|
clear_tile(o_acc);
|
||||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||||
clear_tile(l);
|
clear_tile(l);
|
||||||
const auto q_origin = q_dram_window.get_window_origin();
|
|
||||||
const auto tile_range_result = [&mask, &q_origin]() {
|
const auto q_origin = q_dram_window.get_window_origin();
|
||||||
if constexpr(kHasSink)
|
const auto [logical_seqlen_k_start, logical_seqlen_k_end] =
|
||||||
return mask.GetSinkTileRangeAlongX(
|
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
|
||||||
else
|
|
||||||
{
|
|
||||||
auto [start, end] =
|
|
||||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
|
||||||
return ck_tile::make_tuple(0, start, end);
|
|
||||||
}
|
|
||||||
}();
|
|
||||||
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
|
|
||||||
const auto logical_seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
|
|
||||||
const auto logical_seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
|
|
||||||
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
|
|
||||||
|
|
||||||
// check early exit if no work to do
|
// check early exit if no work to do
|
||||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
|
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
|
||||||
@@ -268,6 +255,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
|||||||
return o_acc;
|
return o_acc;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// k_dram_block_window
|
// k_dram_block_window
|
||||||
const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
|
const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
|
||||||
const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
|
const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
|
||||||
@@ -286,36 +274,27 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
|||||||
return physical_seqlen_k_start_;
|
return physical_seqlen_k_start_;
|
||||||
}
|
}
|
||||||
}();
|
}();
|
||||||
const auto kv_load_start = (sink_seq_end == 0 && aligned_physical_seqlen_k_start > 0)
|
|
||||||
? aligned_physical_seqlen_k_start
|
|
||||||
: 0;
|
|
||||||
const index_t num_total_loop =
|
const index_t num_total_loop =
|
||||||
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0) +
|
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
|
||||||
num_sink_loop;
|
|
||||||
|
|
||||||
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
|
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
|
||||||
k_dram_block_window_lengths, {kv_load_start, 0});
|
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
|
||||||
|
|
||||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
|
||||||
const index_t bias_n_offset = [&]() {
|
|
||||||
if constexpr(kHasSink)
|
|
||||||
return kv_load_start;
|
|
||||||
else
|
|
||||||
return logical_seqlen_k_start -
|
|
||||||
(physical_seqlen_k_start - aligned_physical_seqlen_k_start);
|
|
||||||
}();
|
|
||||||
|
|
||||||
|
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||||
auto bias_dram_window =
|
auto bias_dram_window =
|
||||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||||
bias_dram_block_window_tmp.get_window_lengths(),
|
bias_dram_block_window_tmp.get_window_lengths(),
|
||||||
{bias_origin.at(number<0>{}), bias_n_offset},
|
{bias_origin.at(number<0>{}),
|
||||||
|
logical_seqlen_k_start - (physical_seqlen_k_start -
|
||||||
|
aligned_physical_seqlen_k_start)}, // M/N
|
||||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||||
|
|
||||||
// v_dram_window
|
// v_dram_window
|
||||||
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
|
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
|
||||||
v_dram_block_window_lengths,
|
v_dram_block_window_lengths,
|
||||||
{0, kv_load_start}, // TODO: hdim split?
|
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
|
||||||
Policy::template MakeVDramTileDistribution<Problem>());
|
Policy::template MakeVDramTileDistribution<Problem>());
|
||||||
|
|
||||||
auto q_tile = tile_elementwise_in(q_element_func, q);
|
auto q_tile = tile_elementwise_in(q_element_func, q);
|
||||||
|
|
||||||
// prefetch K tile
|
// prefetch K tile
|
||||||
@@ -342,16 +321,9 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
|||||||
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
|
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
|
||||||
k_block_tile = load_tile(k_dram_window);
|
k_block_tile = load_tile(k_dram_window);
|
||||||
}
|
}
|
||||||
const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops);
|
|
||||||
const auto k_move_offset = [&]() {
|
|
||||||
if constexpr(kHasSink)
|
|
||||||
return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0;
|
|
||||||
else
|
|
||||||
return kN0;
|
|
||||||
}();
|
|
||||||
auto physical_next_block_id_k =
|
auto physical_next_block_id_k =
|
||||||
amd_wave_read_first_lane(k_page_block_navigator.prefetch_table_id(
|
amd_wave_read_first_lane(k_page_block_navigator.prefetch_table_id(
|
||||||
i_page_block_k, k_dram_block_window, {k_move_offset, 0}));
|
i_page_block_k, k_dram_block_window, {kN0, 0}));
|
||||||
auto physical_next_block_id_v = amd_wave_read_first_lane(
|
auto physical_next_block_id_v = amd_wave_read_first_lane(
|
||||||
v_page_block_navigator.prefetch_table_id(i_page_block_v, v_dram_window, {0, kK1}));
|
v_page_block_navigator.prefetch_table_id(i_page_block_v, v_dram_window, {0, kK1}));
|
||||||
|
|
||||||
@@ -470,7 +442,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
move_tile_window(bias_dram_window, {0, k_move_offset});
|
move_tile_window(bias_dram_window, {0, kN0});
|
||||||
|
|
||||||
{
|
{
|
||||||
const auto k_origin = k_page_block_navigator.to_global_window_origin(
|
const auto k_origin = k_page_block_navigator.to_global_window_origin(
|
||||||
@@ -502,29 +474,14 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
|||||||
number<kN0>{});
|
number<kN0>{});
|
||||||
if(need_perpixel_check)
|
if(need_perpixel_check)
|
||||||
{
|
{
|
||||||
auto apply_mask = [&](auto&& mask_func) {
|
set_tile_if(
|
||||||
set_tile_if(s_acc,
|
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||||
-numeric<SMPLComputeDataType>::infinity(),
|
const auto row =
|
||||||
[&](auto tile_idx) {
|
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||||
const auto row =
|
const auto col =
|
||||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||||
const auto col =
|
return mask.IsOutOfBound(row, col - kv_l2p_offset);
|
||||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
|
||||||
return mask_func(row, col - kv_l2p_offset);
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
if constexpr(kHasSink)
|
|
||||||
{
|
|
||||||
apply_mask([&](auto row, auto col) {
|
|
||||||
return mask.IsOutOfSinkBound(row, col);
|
|
||||||
});
|
});
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
apply_mask(
|
|
||||||
[&](auto row, auto col) { return mask.IsOutOfBound(row, col); });
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -690,12 +647,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
|||||||
}
|
}
|
||||||
// move K tile windows
|
// move K tile windows
|
||||||
i_page_block_k = k_page_block_navigator.move_tile_window(
|
i_page_block_k = k_page_block_navigator.move_tile_window(
|
||||||
i_page_block_k, k_dram_block_window, {k_move_offset, 0}, physical_next_block_id_k);
|
i_page_block_k, k_dram_block_window, {kN0, 0}, physical_next_block_id_k);
|
||||||
physical_next_block_id_v =
|
|
||||||
amd_wave_read_first_lane(v_page_block_navigator.prefetch_table_id(
|
|
||||||
i_page_block_v, v_dram_window, {0, k_move_offset - kN0}));
|
|
||||||
i_page_block_v = v_page_block_navigator.move_tile_window(
|
|
||||||
i_page_block_v, v_dram_window, {0, k_move_offset - kN0}, physical_next_block_id_v);
|
|
||||||
// tail
|
// tail
|
||||||
{
|
{
|
||||||
block_sync_lds();
|
block_sync_lds();
|
||||||
|
|||||||
@@ -57,7 +57,6 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
|||||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||||
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
|
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
|
||||||
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
|
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
|
||||||
static constexpr bool kHasSink = Problem::kHasSink;
|
|
||||||
|
|
||||||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
||||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||||
@@ -257,23 +256,11 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
|||||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||||
clear_tile(l);
|
clear_tile(l);
|
||||||
|
|
||||||
const auto q_origin = q_dram_window.get_window_origin();
|
const auto q_origin = q_dram_window.get_window_origin();
|
||||||
const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() {
|
const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
|
||||||
if constexpr(kHasSink)
|
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
||||||
return mask.GetSinkTileRangeAlongX(
|
|
||||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
|
||||||
else
|
|
||||||
{
|
|
||||||
auto [start, end] = mask.GetTileRangeAlongX(
|
|
||||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
|
||||||
return ck_tile::make_tuple(0, start, end);
|
|
||||||
}
|
|
||||||
}();
|
|
||||||
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
|
|
||||||
const auto logical_seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
|
|
||||||
const auto logical_seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
|
|
||||||
|
|
||||||
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
|
// check early exit if no work to do
|
||||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
|
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
|
||||||
{
|
{
|
||||||
const index_t logical_num_total_loop =
|
const index_t logical_num_total_loop =
|
||||||
@@ -317,33 +304,24 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
|||||||
return physical_seqlen_k_start_;
|
return physical_seqlen_k_start_;
|
||||||
}
|
}
|
||||||
}();
|
}();
|
||||||
const auto kv_load_start = (sink_seq_end == 0 && aligned_physical_seqlen_k_start > 0)
|
|
||||||
? aligned_physical_seqlen_k_start
|
|
||||||
: 0;
|
|
||||||
const index_t num_total_loop =
|
const index_t num_total_loop =
|
||||||
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0) +
|
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
|
||||||
num_sink_loop;
|
|
||||||
|
|
||||||
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
|
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
|
||||||
k_dram_block_window_lengths, {kv_load_start, 0});
|
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
|
||||||
|
|
||||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||||
const index_t bias_n_offset = [&]() {
|
|
||||||
if constexpr(kHasSink)
|
|
||||||
return kv_load_start;
|
|
||||||
else
|
|
||||||
return logical_seqlen_k_start -
|
|
||||||
(physical_seqlen_k_start - aligned_physical_seqlen_k_start);
|
|
||||||
}();
|
|
||||||
auto bias_dram_window =
|
auto bias_dram_window =
|
||||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||||
bias_dram_block_window_tmp.get_window_lengths(),
|
bias_dram_block_window_tmp.get_window_lengths(),
|
||||||
{bias_origin.at(number<0>{}), bias_n_offset},
|
{bias_origin.at(number<0>{}),
|
||||||
|
logical_seqlen_k_start - (physical_seqlen_k_start -
|
||||||
|
aligned_physical_seqlen_k_start)}, // M/N
|
||||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||||
|
|
||||||
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
|
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
|
||||||
v_dram_block_window_lengths,
|
v_dram_block_window_lengths,
|
||||||
{0, kv_load_start}, // TODO: hdim split?
|
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
|
||||||
Policy::template MakeVDramTileDistribution<Problem>());
|
Policy::template MakeVDramTileDistribution<Problem>());
|
||||||
|
|
||||||
// store Q into LDS
|
// store Q into LDS
|
||||||
@@ -391,13 +369,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
|||||||
{
|
{
|
||||||
// STAGE 1, QK gemm
|
// STAGE 1, QK gemm
|
||||||
clear_tile(s_acc); // initialize C
|
clear_tile(s_acc); // initialize C
|
||||||
const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops);
|
|
||||||
const auto k_move_offset = [&]() {
|
|
||||||
if constexpr(kHasSink)
|
|
||||||
return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0;
|
|
||||||
else
|
|
||||||
return kN0;
|
|
||||||
}();
|
|
||||||
// load the second tile of the first iteration
|
// load the second tile of the first iteration
|
||||||
k_block_tile = load_tile(k_dram_window);
|
k_block_tile = load_tile(k_dram_window);
|
||||||
|
|
||||||
@@ -522,7 +494,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
move_tile_window(bias_dram_window, {0, k_move_offset});
|
move_tile_window(bias_dram_window, {0, kN0});
|
||||||
|
|
||||||
/// TODO: only check in first/last iteration without increasing code size
|
/// TODO: only check in first/last iteration without increasing code size
|
||||||
if constexpr(kHasUnevenSplits)
|
if constexpr(kHasUnevenSplits)
|
||||||
@@ -533,7 +505,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
|||||||
s_acc,
|
s_acc,
|
||||||
-numeric<SMPLComputeDataType>::infinity(),
|
-numeric<SMPLComputeDataType>::infinity(),
|
||||||
[&,
|
[&,
|
||||||
physical_seqlen_k_start_ = is_sink_tile ? 0 : physical_seqlen_k_start,
|
physical_seqlen_k_start_ = physical_seqlen_k_start,
|
||||||
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
|
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
|
||||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||||
if constexpr(kIsPagedKV)
|
if constexpr(kIsPagedKV)
|
||||||
@@ -558,26 +530,12 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
|||||||
number<kN0>{});
|
number<kN0>{});
|
||||||
if(need_perpixel_check)
|
if(need_perpixel_check)
|
||||||
{
|
{
|
||||||
auto apply_mask = [&](auto&& mask_func) {
|
set_tile_if(
|
||||||
set_tile_if(
|
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||||
const auto row =
|
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
return mask.IsOutOfBound(row, col - kv_l2p_offset);
|
||||||
const auto col =
|
});
|
||||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
|
||||||
return mask_func(row, col - kv_l2p_offset);
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
if constexpr(kHasSink)
|
|
||||||
{
|
|
||||||
apply_mask(
|
|
||||||
[&](auto row, auto col) { return mask.IsOutOfSinkBound(row, col); });
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
apply_mask([&](auto row, auto col) { return mask.IsOutOfBound(row, col); });
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -588,7 +546,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
|||||||
{
|
{
|
||||||
// move K tile windows
|
// move K tile windows
|
||||||
i_page_block_k = k_page_block_navigator.move_tile_window(
|
i_page_block_k = k_page_block_navigator.move_tile_window(
|
||||||
i_page_block_k, k_dram_block_window, {k_move_offset, 0});
|
i_page_block_k, k_dram_block_window, {kN0, 0});
|
||||||
|
|
||||||
k_dram_window = make_tile_window(
|
k_dram_window = make_tile_window(
|
||||||
k_dram_block_window,
|
k_dram_block_window,
|
||||||
@@ -784,8 +742,6 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
|||||||
// moving k_dram_window is an in-page-block operation, so there is
|
// moving k_dram_window is an in-page-block operation, so there is
|
||||||
// no need to invoke k_page_block_navigator.move_tile_window() here.
|
// no need to invoke k_page_block_navigator.move_tile_window() here.
|
||||||
move_tile_window(k_dram_window, {0, kK0});
|
move_tile_window(k_dram_window, {0, kK0});
|
||||||
i_page_block_v = v_page_block_navigator.move_tile_window(
|
|
||||||
i_page_block_v, v_dram_window, {0, k_move_offset - kN0});
|
|
||||||
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
|
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
|
||||||
}
|
}
|
||||||
} while(++i_total_loops < num_total_loop);
|
} while(++i_total_loops < num_total_loop);
|
||||||
|
|||||||
@@ -56,7 +56,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
|||||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||||
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
|
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
|
||||||
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
|
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
|
||||||
static constexpr bool kHasSink = Problem::kHasSink;
|
|
||||||
|
|
||||||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
||||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||||
@@ -230,23 +229,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
|||||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||||
clear_tile(l);
|
clear_tile(l);
|
||||||
|
|
||||||
const auto q_origin = q_dram_window.get_window_origin();
|
const auto q_origin = q_dram_window.get_window_origin();
|
||||||
const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() {
|
const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
|
||||||
if constexpr(kHasSink)
|
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
||||||
return mask.GetSinkTileRangeAlongX(
|
|
||||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
|
||||||
else
|
|
||||||
{
|
|
||||||
auto [start, end] = mask.GetTileRangeAlongX(
|
|
||||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
|
||||||
return ck_tile::make_tuple(0, start, end);
|
|
||||||
}
|
|
||||||
}();
|
|
||||||
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
|
|
||||||
const auto logical_seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
|
|
||||||
const auto logical_seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
|
|
||||||
|
|
||||||
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
|
|
||||||
|
|
||||||
// check early exit if no work to do
|
// check early exit if no work to do
|
||||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
|
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
|
||||||
@@ -289,35 +274,24 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
|||||||
return physical_seqlen_k_start_;
|
return physical_seqlen_k_start_;
|
||||||
}
|
}
|
||||||
}();
|
}();
|
||||||
const auto kv_load_start = (sink_seq_end == 0 && aligned_physical_seqlen_k_start > 0)
|
|
||||||
? aligned_physical_seqlen_k_start
|
|
||||||
: 0;
|
|
||||||
const index_t num_total_loop =
|
const index_t num_total_loop =
|
||||||
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0) +
|
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
|
||||||
num_sink_loop;
|
|
||||||
|
|
||||||
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
|
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
|
||||||
k_dram_block_window_lengths, {kv_load_start, 0});
|
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
|
||||||
|
|
||||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||||
|
|
||||||
const index_t bias_n_offset = [&]() {
|
|
||||||
if constexpr(kHasSink)
|
|
||||||
return kv_load_start;
|
|
||||||
else
|
|
||||||
return logical_seqlen_k_start -
|
|
||||||
(physical_seqlen_k_start - aligned_physical_seqlen_k_start);
|
|
||||||
}();
|
|
||||||
|
|
||||||
auto bias_dram_window =
|
auto bias_dram_window =
|
||||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||||
bias_dram_block_window_tmp.get_window_lengths(),
|
bias_dram_block_window_tmp.get_window_lengths(),
|
||||||
{bias_origin.at(number<0>{}), bias_n_offset},
|
{bias_origin.at(number<0>{}),
|
||||||
|
logical_seqlen_k_start - (physical_seqlen_k_start -
|
||||||
|
aligned_physical_seqlen_k_start)}, // M/N
|
||||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||||
|
|
||||||
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
|
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
|
||||||
v_dram_block_window_lengths,
|
v_dram_block_window_lengths,
|
||||||
{0, kv_load_start}, // TODO: hdim split?
|
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
|
||||||
Policy::template MakeVDramTileDistribution<Problem>());
|
Policy::template MakeVDramTileDistribution<Problem>());
|
||||||
|
|
||||||
auto q_tile = tile_elementwise_in(q_element_func, q);
|
auto q_tile = tile_elementwise_in(q_element_func, q);
|
||||||
@@ -346,18 +320,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
|||||||
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
|
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
|
||||||
k_block_tile = load_tile(k_dram_window);
|
k_block_tile = load_tile(k_dram_window);
|
||||||
}
|
}
|
||||||
const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops);
|
|
||||||
|
|
||||||
const auto k_move_offset = [&]() {
|
|
||||||
if constexpr(kHasSink)
|
|
||||||
return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0;
|
|
||||||
else
|
|
||||||
return kN0;
|
|
||||||
}();
|
|
||||||
|
|
||||||
auto physical_next_block_id_k =
|
auto physical_next_block_id_k =
|
||||||
amd_wave_read_first_lane(k_page_block_navigator.prefetch_table_id(
|
amd_wave_read_first_lane(k_page_block_navigator.prefetch_table_id(
|
||||||
i_page_block_k, k_dram_block_window, {k_move_offset, 0}));
|
i_page_block_k, k_dram_block_window, {kN0, 0}));
|
||||||
auto physical_next_block_id_v = amd_wave_read_first_lane(
|
auto physical_next_block_id_v = amd_wave_read_first_lane(
|
||||||
v_page_block_navigator.prefetch_table_id(i_page_block_v, v_dram_window, {0, kK1}));
|
v_page_block_navigator.prefetch_table_id(i_page_block_v, v_dram_window, {0, kK1}));
|
||||||
|
|
||||||
@@ -476,7 +441,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
move_tile_window(bias_dram_window, {0, k_move_offset});
|
move_tile_window(bias_dram_window, {0, kN0});
|
||||||
|
|
||||||
/// TODO: only check in first/last iteration without increasing code size
|
/// TODO: only check in first/last iteration without increasing code size
|
||||||
if constexpr(kHasUnevenSplits)
|
if constexpr(kHasUnevenSplits)
|
||||||
@@ -487,7 +452,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
|||||||
s_acc,
|
s_acc,
|
||||||
-numeric<SMPLComputeDataType>::infinity(),
|
-numeric<SMPLComputeDataType>::infinity(),
|
||||||
[&,
|
[&,
|
||||||
physical_seqlen_k_start_ = is_sink_tile ? 0 : physical_seqlen_k_start,
|
physical_seqlen_k_start_ = physical_seqlen_k_start,
|
||||||
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
|
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
|
||||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||||
if constexpr(kIsPagedKV)
|
if constexpr(kIsPagedKV)
|
||||||
@@ -512,26 +477,12 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
|||||||
number<kN0>{});
|
number<kN0>{});
|
||||||
if(need_perpixel_check)
|
if(need_perpixel_check)
|
||||||
{
|
{
|
||||||
auto apply_mask = [&](auto&& mask_func) {
|
set_tile_if(
|
||||||
set_tile_if(
|
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||||
const auto row =
|
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
return mask.IsOutOfBound(row, col - kv_l2p_offset);
|
||||||
const auto col =
|
});
|
||||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
|
||||||
return mask_func(row, col - kv_l2p_offset);
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
if constexpr(kHasSink)
|
|
||||||
{
|
|
||||||
apply_mask(
|
|
||||||
[&](auto row, auto col) { return mask.IsOutOfSinkBound(row, col); });
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
apply_mask([&](auto row, auto col) { return mask.IsOutOfBound(row, col); });
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -696,12 +647,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
|||||||
}
|
}
|
||||||
// move K tile windows
|
// move K tile windows
|
||||||
i_page_block_k = k_page_block_navigator.move_tile_window(
|
i_page_block_k = k_page_block_navigator.move_tile_window(
|
||||||
i_page_block_k, k_dram_block_window, {k_move_offset, 0}, physical_next_block_id_k);
|
i_page_block_k, k_dram_block_window, {kN0, 0}, physical_next_block_id_k);
|
||||||
physical_next_block_id_v =
|
|
||||||
amd_wave_read_first_lane(v_page_block_navigator.prefetch_table_id(
|
|
||||||
i_page_block_v, v_dram_window, {0, k_move_offset - kN0}));
|
|
||||||
i_page_block_v = v_page_block_navigator.move_tile_window(
|
|
||||||
i_page_block_v, v_dram_window, {0, k_move_offset - kN0}, physical_next_block_id_v);
|
|
||||||
// tail
|
// tail
|
||||||
{
|
{
|
||||||
block_sync_lds();
|
block_sync_lds();
|
||||||
|
|||||||
@@ -62,7 +62,6 @@ struct BlockFmhaPipelineProblem
|
|||||||
static constexpr bool kHasDropout = Traits::kHasDropout;
|
static constexpr bool kHasDropout = Traits::kHasDropout;
|
||||||
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
|
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
|
||||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||||
static constexpr bool kHasSink = Traits::kHasSink;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename QDataType_,
|
template <typename QDataType_,
|
||||||
@@ -115,7 +114,6 @@ struct BlockFmhaFwdPagedKVPipelineProblem
|
|||||||
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
|
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
|
||||||
static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
|
static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
|
||||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||||
static constexpr bool kHasSink = Traits::kHasSink;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename QDataType_,
|
template <typename QDataType_,
|
||||||
@@ -169,7 +167,6 @@ struct BlockFmhaFwdSplitKVPipelineProblem
|
|||||||
static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
|
static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
|
||||||
static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ;
|
static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ;
|
||||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||||
static constexpr bool kHasSink = Traits::kHasSink;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// extract tile size attributes to remove dependency on traits
|
// extract tile size attributes to remove dependency on traits
|
||||||
|
|||||||
@@ -57,7 +57,6 @@ struct BlockFmhaPipelineQRKSVS
|
|||||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||||
static constexpr bool kHasSink = Problem::kHasSink;
|
|
||||||
|
|
||||||
static constexpr uint32_t DS_READ = 0x100; // Barrier for DS (data share) read
|
static constexpr uint32_t DS_READ = 0x100; // Barrier for DS (data share) read
|
||||||
static constexpr uint32_t MFMA = 0x008; // Barrier for MFMA (matrix multiply-accumulate)
|
static constexpr uint32_t MFMA = 0x008; // Barrier for MFMA (matrix multiply-accumulate)
|
||||||
@@ -234,26 +233,10 @@ struct BlockFmhaPipelineQRKSVS
|
|||||||
clear_tile(l);
|
clear_tile(l);
|
||||||
|
|
||||||
const auto q_origin = q_dram_window.get_window_origin();
|
const auto q_origin = q_dram_window.get_window_origin();
|
||||||
|
const auto [seqlen_k_start, seqlen_k_end] =
|
||||||
|
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||||
|
|
||||||
const auto tile_range_result = [&mask, &q_origin]() {
|
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||||
if constexpr(kHasSink)
|
|
||||||
return mask.GetSinkTileRangeAlongX(
|
|
||||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
|
||||||
else
|
|
||||||
{
|
|
||||||
auto [start, end] =
|
|
||||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
|
||||||
return ck_tile::make_tuple(0, start, end);
|
|
||||||
}
|
|
||||||
}();
|
|
||||||
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
|
|
||||||
const auto seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
|
|
||||||
const auto seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
|
|
||||||
|
|
||||||
const auto kv_load_start = (sink_seq_end == 0 && seqlen_k_start > 0) ? seqlen_k_start : 0;
|
|
||||||
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
|
|
||||||
const auto num_total_loop =
|
|
||||||
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0) + num_sink_loop;
|
|
||||||
|
|
||||||
// check early exit if no work to do
|
// check early exit if no work to do
|
||||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
|
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
|
||||||
@@ -279,22 +262,22 @@ struct BlockFmhaPipelineQRKSVS
|
|||||||
auto k_dram_block_window =
|
auto k_dram_block_window =
|
||||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||||
k_dram_block_window_tmp.get_window_lengths(),
|
k_dram_block_window_tmp.get_window_lengths(),
|
||||||
{kv_load_start, 0});
|
{seqlen_k_start, 0});
|
||||||
|
|
||||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||||
auto bias_dram_window =
|
auto bias_dram_window =
|
||||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||||
bias_dram_block_window_tmp.get_window_lengths(),
|
bias_dram_block_window_tmp.get_window_lengths(),
|
||||||
{bias_origin.at(number<0>{}), kv_load_start}, // M/N
|
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
|
||||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||||
|
|
||||||
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
|
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
|
||||||
randval_dram_block_window_tmp, kv_load_start);
|
randval_dram_block_window_tmp, seqlen_k_start);
|
||||||
|
|
||||||
auto v_dram_window =
|
auto v_dram_window =
|
||||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||||
v_dram_block_window_tmp.get_window_lengths(),
|
v_dram_block_window_tmp.get_window_lengths(),
|
||||||
{0, kv_load_start}, // TODO: hdim split?
|
{0, seqlen_k_start}, // TODO: hdim split?
|
||||||
Policy::template MakeVDramTileDistribution<Problem>());
|
Policy::template MakeVDramTileDistribution<Problem>());
|
||||||
|
|
||||||
auto q_tile = tile_elementwise_in(q_element_func, q);
|
auto q_tile = tile_elementwise_in(q_element_func, q);
|
||||||
@@ -467,11 +450,6 @@ struct BlockFmhaPipelineQRKSVS
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if constexpr(kHasSink)
|
|
||||||
{
|
|
||||||
if(i_total_loops == 0)
|
|
||||||
move_tile_window(bias_dram_window, {0, seqlen_k_start - sink_seq_end});
|
|
||||||
}
|
|
||||||
move_tile_window(bias_dram_window, {0, kN0});
|
move_tile_window(bias_dram_window, {0, kN0});
|
||||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||||
{
|
{
|
||||||
@@ -482,34 +460,17 @@ struct BlockFmhaPipelineQRKSVS
|
|||||||
number<kN0>{});
|
number<kN0>{});
|
||||||
if(need_perpixel_check)
|
if(need_perpixel_check)
|
||||||
{
|
{
|
||||||
auto apply_mask = [&](auto&& mask_func) {
|
set_tile_if(
|
||||||
set_tile_if(
|
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||||
const auto row =
|
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
return !variant.LogitsMask(variant_params,
|
||||||
const auto col =
|
block_indices.batch_idx,
|
||||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
row,
|
||||||
return !mask_func(variant_params,
|
col,
|
||||||
block_indices.batch_idx,
|
block_indices.qo_head_idx,
|
||||||
row,
|
block_indices.kv_head_idx);
|
||||||
col,
|
|
||||||
block_indices.qo_head_idx,
|
|
||||||
block_indices.kv_head_idx);
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
if constexpr(kHasSink)
|
|
||||||
{
|
|
||||||
apply_mask([&](auto&&... args) {
|
|
||||||
return variant.LogitsSinkMask(std::forward<decltype(args)>(args)...);
|
|
||||||
});
|
});
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
apply_mask([&](auto&&... args) {
|
|
||||||
return variant.LogitsMask(std::forward<decltype(args)>(args)...);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -619,23 +580,11 @@ struct BlockFmhaPipelineQRKSVS
|
|||||||
|
|
||||||
if constexpr(kHasDropout)
|
if constexpr(kHasDropout)
|
||||||
{
|
{
|
||||||
|
// K and dropout use the same address in LDS, finish loading from k_lds_window by
|
||||||
|
// gemm_0 to reuse LDS.
|
||||||
block_sync_lds();
|
block_sync_lds();
|
||||||
auto randval_ptr = reinterpret_cast<char*>(smem_ptr);
|
|
||||||
|
|
||||||
index_t seq_offset = [&]() {
|
|
||||||
if constexpr(!kHasSink)
|
|
||||||
return seqlen_k_start + i_total_loops * kN0;
|
|
||||||
|
|
||||||
const bool in_sink_phase = (num_sink_loop > i_total_loops);
|
|
||||||
if(i_total_loops == num_sink_loop)
|
|
||||||
move_tile_window(randval_dram_window, {0, seqlen_k_start - sink_seq_end});
|
|
||||||
|
|
||||||
return in_sink_phase ? (kv_load_start + i_total_loops * kN0)
|
|
||||||
: (seqlen_k_start + (i_total_loops - num_sink_loop) * kN0);
|
|
||||||
}();
|
|
||||||
|
|
||||||
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
|
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
|
||||||
randval_ptr, seq_offset, p_compute, randval_dram_window);
|
smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
|
||||||
}
|
}
|
||||||
|
|
||||||
block_sync_lds();
|
block_sync_lds();
|
||||||
@@ -687,14 +636,6 @@ struct BlockFmhaPipelineQRKSVS
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
// move K tile windows
|
// move K tile windows
|
||||||
if constexpr(kHasSink)
|
|
||||||
{
|
|
||||||
if(i_total_loops == 0)
|
|
||||||
{
|
|
||||||
move_tile_window(k_dram_block_window, {seqlen_k_start - sink_seq_end, 0});
|
|
||||||
move_tile_window(v_dram_window, {0, seqlen_k_start - sink_seq_end});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
move_tile_window(k_dram_block_window, {kN0, 0});
|
move_tile_window(k_dram_block_window, {kN0, 0});
|
||||||
// tail
|
// tail
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -62,7 +62,6 @@ struct BlockFmhaPipelineQRKSVSAsync
|
|||||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||||
static constexpr bool kHasSink = Problem::kHasSink;
|
|
||||||
|
|
||||||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
||||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||||
@@ -278,26 +277,11 @@ struct BlockFmhaPipelineQRKSVSAsync
|
|||||||
clear_tile(l);
|
clear_tile(l);
|
||||||
|
|
||||||
__builtin_amdgcn_sched_barrier(0);
|
__builtin_amdgcn_sched_barrier(0);
|
||||||
const auto q_origin = q_dram_window.get_window_origin();
|
const auto q_origin = q_dram_window.get_window_origin();
|
||||||
const auto tile_range_result = [&mask, &q_origin]() {
|
const auto [seqlen_k_start, seqlen_k_end] =
|
||||||
if constexpr(kHasSink)
|
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||||
return mask.GetSinkTileRangeAlongX(
|
|
||||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
|
||||||
else
|
|
||||||
{
|
|
||||||
auto [start, end] =
|
|
||||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
|
||||||
return ck_tile::make_tuple(0, start, end);
|
|
||||||
}
|
|
||||||
}();
|
|
||||||
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
|
|
||||||
const auto seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
|
|
||||||
const auto seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
|
|
||||||
|
|
||||||
const auto kv_load_start = (sink_seq_end == 0 && seqlen_k_start > 0) ? seqlen_k_start : 0;
|
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||||
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
|
|
||||||
const auto num_total_loop =
|
|
||||||
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0) + num_sink_loop;
|
|
||||||
|
|
||||||
// check early exit if no work to do
|
// check early exit if no work to do
|
||||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
|
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
|
||||||
@@ -325,7 +309,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
|||||||
auto k_dram_block_window =
|
auto k_dram_block_window =
|
||||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||||
k_dram_block_window_tmp.get_window_lengths(),
|
k_dram_block_window_tmp.get_window_lengths(),
|
||||||
{kv_load_start, 0});
|
{seqlen_k_start, 0});
|
||||||
|
|
||||||
auto k_dram_window = make_tile_window(
|
auto k_dram_window = make_tile_window(
|
||||||
k_dram_block_window.get_bottom_tensor_view(),
|
k_dram_block_window.get_bottom_tensor_view(),
|
||||||
@@ -348,16 +332,16 @@ struct BlockFmhaPipelineQRKSVSAsync
|
|||||||
auto bias_dram_window =
|
auto bias_dram_window =
|
||||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||||
bias_dram_block_window_tmp.get_window_lengths(),
|
bias_dram_block_window_tmp.get_window_lengths(),
|
||||||
{bias_origin.at(number<0>{}), kv_load_start}, // M/N
|
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
|
||||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||||
|
|
||||||
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
|
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
|
||||||
randval_dram_block_window_tmp, kv_load_start);
|
randval_dram_block_window_tmp, seqlen_k_start);
|
||||||
|
|
||||||
auto v_dram_window =
|
auto v_dram_window =
|
||||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||||
v_dram_block_window_tmp.get_window_lengths(),
|
v_dram_block_window_tmp.get_window_lengths(),
|
||||||
{0, kv_load_start}, // TODO: hdim split?
|
{0, seqlen_k_start}, // TODO: hdim split?
|
||||||
Policy::template MakeVDramTileDistribution<Problem>());
|
Policy::template MakeVDramTileDistribution<Problem>());
|
||||||
|
|
||||||
// prefetch K tile
|
// prefetch K tile
|
||||||
@@ -494,11 +478,6 @@ struct BlockFmhaPipelineQRKSVSAsync
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if constexpr(kHasSink)
|
|
||||||
{
|
|
||||||
if(i_total_loops == 0)
|
|
||||||
move_tile_window(bias_dram_window, {0, seqlen_k_start - sink_seq_end});
|
|
||||||
}
|
|
||||||
move_tile_window(bias_dram_window, {0, kN0});
|
move_tile_window(bias_dram_window, {0, kN0});
|
||||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||||
{
|
{
|
||||||
@@ -510,34 +489,17 @@ struct BlockFmhaPipelineQRKSVSAsync
|
|||||||
|
|
||||||
if(need_perpixel_check)
|
if(need_perpixel_check)
|
||||||
{
|
{
|
||||||
auto apply_mask = [&](auto&& mask_func) {
|
set_tile_if(
|
||||||
set_tile_if(
|
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||||
const auto row =
|
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
return !variant.LogitsMask(variant_params,
|
||||||
const auto col =
|
block_indices.batch_idx,
|
||||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
row,
|
||||||
return !mask_func(variant_params,
|
col,
|
||||||
block_indices.batch_idx,
|
block_indices.qo_head_idx,
|
||||||
row,
|
block_indices.kv_head_idx);
|
||||||
col,
|
|
||||||
block_indices.qo_head_idx,
|
|
||||||
block_indices.kv_head_idx);
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
if constexpr(kHasSink)
|
|
||||||
{
|
|
||||||
apply_mask([&](auto&&... args) {
|
|
||||||
return variant.LogitsSinkMask(std::forward<decltype(args)>(args)...);
|
|
||||||
});
|
});
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
apply_mask([&](auto&&... args) {
|
|
||||||
return variant.LogitsMask(std::forward<decltype(args)>(args)...);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -685,21 +647,11 @@ struct BlockFmhaPipelineQRKSVSAsync
|
|||||||
{
|
{
|
||||||
auto randval_ptr =
|
auto randval_ptr =
|
||||||
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
|
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
|
||||||
|
|
||||||
index_t seq_offset = [&]() {
|
|
||||||
if constexpr(!kHasSink)
|
|
||||||
return seqlen_k_start + i_total_loops * kN0;
|
|
||||||
|
|
||||||
const bool in_sink_phase = (num_sink_loop > i_total_loops);
|
|
||||||
if(i_total_loops == num_sink_loop)
|
|
||||||
move_tile_window(randval_dram_window, {0, seqlen_k_start - sink_seq_end});
|
|
||||||
|
|
||||||
return in_sink_phase ? (kv_load_start + i_total_loops * kN0)
|
|
||||||
: (seqlen_k_start + (i_total_loops - num_sink_loop) * kN0);
|
|
||||||
}();
|
|
||||||
|
|
||||||
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
|
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
|
||||||
randval_ptr, seq_offset, p_compute, randval_dram_window);
|
randval_ptr,
|
||||||
|
seqlen_k_start + i_total_loops * kN0,
|
||||||
|
p_compute,
|
||||||
|
randval_dram_window);
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto p = [&]() {
|
const auto p = [&]() {
|
||||||
@@ -765,16 +717,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
|||||||
i_total_loops++;
|
i_total_loops++;
|
||||||
if(i_total_loops < num_total_loop)
|
if(i_total_loops < num_total_loop)
|
||||||
{
|
{
|
||||||
if constexpr(kHasSink)
|
// move K tile windows
|
||||||
{
|
|
||||||
if(i_total_loops == 0)
|
|
||||||
{
|
|
||||||
move_tile_window(k_dram_block_window, {seqlen_k_start - sink_seq_end, 0});
|
|
||||||
move_tile_window(v_dram_window, {0, seqlen_k_start - sink_seq_end});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
move_tile_window(k_dram_block_window, {kN0, 0});
|
move_tile_window(k_dram_block_window, {kN0, 0});
|
||||||
|
|
||||||
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
|
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
|
||||||
|
|
||||||
if constexpr(k1_loops >= 2 &&
|
if constexpr(k1_loops >= 2 &&
|
||||||
|
|||||||
@@ -69,7 +69,6 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
|||||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||||
static constexpr bool kHasUnevenSplits = true;
|
static constexpr bool kHasUnevenSplits = true;
|
||||||
static constexpr bool kHasSink = Problem::kHasSink;
|
|
||||||
|
|
||||||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
||||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||||
|
|||||||
@@ -19,9 +19,8 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
|||||||
bool kStoreLSE_,
|
bool kStoreLSE_,
|
||||||
bool kHasDropout_,
|
bool kHasDropout_,
|
||||||
bool kDoFp8StaticQuant_,
|
bool kDoFp8StaticQuant_,
|
||||||
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
|
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
|
||||||
bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */
|
bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
|
||||||
bool kHasSink_ = false>
|
|
||||||
struct TileFmhaTraits
|
struct TileFmhaTraits
|
||||||
{
|
{
|
||||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||||
@@ -36,7 +35,6 @@ struct TileFmhaTraits
|
|||||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||||
static constexpr bool kHasSink = kHasSink_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <index_t kPadHeadDimQ_ /* paddding for hdim_q */,
|
template <index_t kPadHeadDimQ_ /* paddding for hdim_q */,
|
||||||
@@ -66,9 +64,8 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
|||||||
bool kStoreLSE_, /* set to true if either num_splits > 1 or fwd training is running */
|
bool kStoreLSE_, /* set to true if either num_splits > 1 or fwd training is running */
|
||||||
bool kIsPagedKV_,
|
bool kIsPagedKV_,
|
||||||
bool kDoFp8StaticQuant_,
|
bool kDoFp8StaticQuant_,
|
||||||
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
|
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
|
||||||
bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */
|
bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
|
||||||
bool kHasSink_ = false>
|
|
||||||
struct TileFmhaFwdPagedKVTraits
|
struct TileFmhaFwdPagedKVTraits
|
||||||
{
|
{
|
||||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||||
@@ -83,7 +80,6 @@ struct TileFmhaFwdPagedKVTraits
|
|||||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||||
static constexpr bool kHasSink = kHasSink_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||||
@@ -98,8 +94,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
|||||||
bool kIsPagedKV_,
|
bool kIsPagedKV_,
|
||||||
bool kHasUnevenSplits_,
|
bool kHasUnevenSplits_,
|
||||||
bool kMergeNumHeadGroupsSeqLenQ_ = false,
|
bool kMergeNumHeadGroupsSeqLenQ_ = false,
|
||||||
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
|
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
|
||||||
bool kHasSink_ = false>
|
|
||||||
struct TileFmhaFwdSplitKVTraits
|
struct TileFmhaFwdSplitKVTraits
|
||||||
{
|
{
|
||||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||||
@@ -116,7 +111,6 @@ struct TileFmhaFwdSplitKVTraits
|
|||||||
static constexpr bool kHasUnevenSplits = kHasUnevenSplits_;
|
static constexpr bool kHasUnevenSplits = kHasUnevenSplits_;
|
||||||
static constexpr bool kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_;
|
static constexpr bool kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_;
|
||||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||||
static constexpr bool kHasSink = kHasSink_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||||
|
|||||||
Reference in New Issue
Block a user