mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
Support fp8 dynamic quantization for fmha (#3206)
* Support qscale for dynamic quant, remove static quant * Support hdim=256 * Remove bias test case for fp8 --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: asleepzzz <hanwen.chang@amd.com>
This commit is contained in:
@@ -17,12 +17,12 @@ The executables reside in `bin` subdirectory of the build directory.
|
||||
|
||||
This example provides recipes for `tile_example_fmha_fwd`, `tile_example_fmha_bwd`, `tile_example_fmha_fwd_v3`.
|
||||
|
||||
> [!NOTE]
|
||||
> `cmake-ck-dev.sh` is a CMake wrapper.
|
||||
> [!NOTE]
|
||||
> `cmake-ck-dev.sh` is a CMake wrapper.
|
||||
>
|
||||
> The first argument is the path to composable_kernel sources.
|
||||
>
|
||||
> The second argument is the gfx architectures string (e.g. "gfx950" or "gfx90a;gfx942").
|
||||
> The second argument is the gfx architectures string (e.g. "gfx950" or "gfx90a;gfx942").
|
||||
>
|
||||
> The remaining arguments are optional and are passed through to CMake.
|
||||
> E.g. `-G Ninja` specifies ninja as the build system.
|
||||
@@ -61,15 +61,8 @@ args:
|
||||
-d head dim for q, k (default:128)
|
||||
-d_v head dim for v, -1 means equal to d (default:-1)
|
||||
-scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0)
|
||||
note when squant=1, this value will be modified by range_q/k
|
||||
-range_q per-tensor quantization range of q. used if squant=1. (default:16)
|
||||
-range_k per-tensor quantization range of k. used if squant=1. (default:16)
|
||||
-range_v per-tensor quantization range of v. used if squant=1. (default:16)
|
||||
-range_p per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1)
|
||||
-range_o per-tensor quantization range of o (p*v). used if squant=1. (default:16)
|
||||
-squant if using static quantization fusion or not. auto: fp8 will default use squant, other will not (default:auto)
|
||||
0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to P and O.
|
||||
calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, range_p, range_o
|
||||
-qscale n or 0, no scaling (default:n)
|
||||
1: per-tensor quantization.
|
||||
-iperm permute input (default:1)
|
||||
if true, will be b*h*s*d, else b*s*h*d
|
||||
-operm permute output (default:1)
|
||||
@@ -104,7 +97,7 @@ args:
|
||||
Comma-separated list of length 'b'. If empty, no override
|
||||
```
|
||||
Example 1: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case.
|
||||
Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with
|
||||
Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with
|
||||
batch=1, nhead=8, sequence length=16384, hdim=64, drop_seed=0 (in GPU memory), drop_offset=1234 (in GPU memory) fp16 case
|
||||
|
||||
## Padding Examples
|
||||
|
||||
@@ -63,6 +63,16 @@ def get_mask_check_map(mask: str):
|
||||
return None
|
||||
|
||||
|
||||
QSCALE_MAP = {
|
||||
"no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE",
|
||||
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
|
||||
}
|
||||
|
||||
QSCALE_CHECK_MAP = {
|
||||
"no": "quant_scale_enum::no_scale",
|
||||
"pertensor": "quant_scale_enum::pertensor",
|
||||
}
|
||||
|
||||
BIAS_MAP = {
|
||||
"no": "ck_tile::BlockAttentionBiasEnum::NO_BIAS",
|
||||
"bias": "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS",
|
||||
|
||||
@@ -24,6 +24,8 @@ from codegen.cpp_symbol_map import (
|
||||
FWD_DTYPE_MAP,
|
||||
BIAS_MAP,
|
||||
get_mask_map,
|
||||
QSCALE_CHECK_MAP,
|
||||
QSCALE_MAP,
|
||||
)
|
||||
from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file
|
||||
|
||||
@@ -64,7 +66,7 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
false,
|
||||
{F_lse},
|
||||
{F_dropout},
|
||||
{F_squant},
|
||||
{F_qscale},
|
||||
{F_occupancy},
|
||||
{F_skip}>;
|
||||
|
||||
@@ -103,7 +105,7 @@ using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
|
||||
|
||||
template<>
|
||||
float fmha_fwd_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
@@ -190,9 +192,9 @@ FMHA_FWD_API_PER_HDIM_CASE = """{F_if}(t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hd
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
|
||||
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) &&
|
||||
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
|
||||
return fmha_fwd_<trait_, {F_arch.tag}>(s, a);
|
||||
}}
|
||||
"""
|
||||
@@ -232,7 +234,7 @@ class FmhaFwdApiTrait:
|
||||
bias: str #
|
||||
lse: str #
|
||||
dropout: str
|
||||
squant: str #
|
||||
qscale: str #
|
||||
spad: str
|
||||
skpad: str
|
||||
dpad: str
|
||||
@@ -245,7 +247,7 @@ class FmhaFwdApiTrait:
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -341,7 +343,7 @@ class FmhaFwdPipeline:
|
||||
F_bias: str # true/false
|
||||
F_lse: str #
|
||||
F_dropout: str #
|
||||
F_squant: str #
|
||||
F_qscale: str # no/pertensor
|
||||
F_mask: str # value from MASK_MAP
|
||||
F_skip: str # true/false
|
||||
F_trload: str # true/false
|
||||
@@ -406,10 +408,10 @@ class FmhaFwdPipeline:
|
||||
else:
|
||||
n += "_nskip"
|
||||
|
||||
if self.F_squant == "t":
|
||||
n += "_squant"
|
||||
if self.F_qscale != "no":
|
||||
n += f"_{self.F_qscale}"
|
||||
else:
|
||||
n += "_nsquant"
|
||||
n += "_nqscale"
|
||||
|
||||
if self.F_trload == "t":
|
||||
n += "_trload"
|
||||
@@ -462,7 +464,8 @@ class FmhaFwdApiPool:
|
||||
F_dropout=BOOL_MAP[trait.dropout],
|
||||
F_skip=BOOL_MAP[trait.skip],
|
||||
F_trload=BOOL_MAP[trait.tr_load],
|
||||
F_squant=BOOL_MAP[trait.squant],
|
||||
F_qscale_check=QSCALE_CHECK_MAP[trait.qscale],
|
||||
F_qscale=QSCALE_MAP[trait.qscale],
|
||||
F_scheck=trait.scheck,
|
||||
F_seqtune=trait.seqtune(max_bm0),
|
||||
F_skcheck=trait.skcheck,
|
||||
@@ -580,7 +583,7 @@ class FmhaFwdKernel:
|
||||
F_bias=BIAS_MAP[self.F_pipeline.F_bias],
|
||||
F_lse=BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_dropout=BOOL_MAP[self.F_pipeline.F_dropout],
|
||||
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
|
||||
F_qscale=QSCALE_MAP[self.F_pipeline.F_qscale],
|
||||
F_skip=BOOL_MAP[self.F_pipeline.F_skip],
|
||||
F_occupancy=self.F_tile.F_occupancy,
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
@@ -623,7 +626,7 @@ class FmhaFwdKernel:
|
||||
bias=self.F_pipeline.F_bias,
|
||||
lse=self.F_pipeline.F_lse,
|
||||
dropout=self.F_pipeline.F_dropout,
|
||||
squant=self.F_pipeline.F_squant,
|
||||
qscale=self.F_pipeline.F_qscale,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
@@ -695,7 +698,7 @@ class KernelComponentFactoryGfx9:
|
||||
# TODO: how to design this more generic?
|
||||
pipelines = []
|
||||
if dtype in ["fp32"]:
|
||||
squant = "f"
|
||||
qscale = "no"
|
||||
for logits, mask, bias, lse, dropout, skip in itertools.product(
|
||||
["t", "f"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
@@ -704,11 +707,11 @@ class KernelComponentFactoryGfx9:
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
|
||||
elif dtype in ["fp16", "bf16"]:
|
||||
squant = "f"
|
||||
qscale = "no"
|
||||
for logits, mask, bias, lse, dropout, skip in itertools.product(
|
||||
["t", "f"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
@@ -718,28 +721,31 @@ class KernelComponentFactoryGfx9:
|
||||
["t", "f"],
|
||||
):
|
||||
if hdim == 256 and hdim_v == 256:
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
|
||||
# the below two is used for hdim vectorize load
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
|
||||
else:
|
||||
if bias == "bias":
|
||||
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
|
||||
if receipt == 1 and bias != "bias":
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip
|
||||
elif dtype in ["fp8", "fp8bf16", "fp8fp32"]:
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip
|
||||
elif dtype in ["fp8bf16", "fp8fp32"]:
|
||||
# no need lse/dropout kernels
|
||||
for logits, squant, mask, bias in itertools.product(
|
||||
["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
|
||||
for logits, qscale, mask, bias in itertools.product(
|
||||
["f"],
|
||||
["no", "pertensor"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["no"],
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip
|
||||
elif dtype in ["fp8fp16", "bf8"]:
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
|
||||
elif dtype in ["fp8", "fp8fp16", "bf8"]:
|
||||
# TODO
|
||||
None
|
||||
else:
|
||||
@@ -756,7 +762,7 @@ class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9):
|
||||
dtype, hdim, hdim_v, receipt, mask_impl
|
||||
)
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
squant = "f"
|
||||
qscale = "no"
|
||||
for logits, mask, bias, lse, dropout, skip in itertools.product(
|
||||
["t", "f"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
@@ -772,8 +778,8 @@ class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9):
|
||||
and dropout == "f"
|
||||
and skip == "f"
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t")) # fmt: skip
|
||||
return pipelines
|
||||
|
||||
|
||||
@@ -810,7 +816,7 @@ class KernelComponentFactoryGfx12:
|
||||
def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
|
||||
pipelines = []
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
squant = "f"
|
||||
qscale = "no"
|
||||
for logits, mask, bias, lse, dropout, skip in itertools.product(
|
||||
["t", "f"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
@@ -819,15 +825,15 @@ class KernelComponentFactoryGfx12:
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
|
||||
elif dtype in ["fp8", "fp8bf16", "fp8fp32"]:
|
||||
# no need lse/dropout kernels
|
||||
for logits, squant, mask, bias in itertools.product(
|
||||
["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
|
||||
for logits, qscale, mask, bias in itertools.product(
|
||||
["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"]
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
@@ -932,7 +938,7 @@ def get_fwd_blobs(
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "alibi"]
|
||||
cond &= pipeline.F_squant == "f"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
cond &= pipeline.F_skip == "f"
|
||||
if not cond:
|
||||
continue
|
||||
@@ -941,7 +947,7 @@ def get_fwd_blobs(
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "bias"]
|
||||
cond &= pipeline.F_squant == "f"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_skip == "f"
|
||||
cond &= pipeline.F_logits == "f"
|
||||
@@ -953,7 +959,7 @@ def get_fwd_blobs(
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if dtype == "fp8bf16":
|
||||
cond &= hdim == 128
|
||||
cond &= hdim == 128 or hdim == 256
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_varlen_fwd) integration
|
||||
@@ -962,7 +968,7 @@ def get_fwd_blobs(
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if dtype == "fp8bf16":
|
||||
cond &= hdim == 128
|
||||
cond &= hdim == 128 or hdim == 256
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_fwd C++ api integration
|
||||
@@ -970,13 +976,13 @@ def get_fwd_blobs(
|
||||
cond = dtype in ["fp16", "bf16", "fp8bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if dtype == "fp8bf16":
|
||||
cond &= hdim == 128
|
||||
cond &= hdim == 128 or hdim == 256
|
||||
if not cond:
|
||||
continue
|
||||
elif receipt == 888:
|
||||
cond = dtype in ["fp8", "fp8bf16", "fp8fp32"]
|
||||
cond = dtype in ["fp8bf16", "fp8fp32"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= hdim == 128
|
||||
cond &= hdim == 128 or hdim == 256
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
|
||||
@@ -45,18 +45,12 @@ auto create_args(int argc, char* argv[])
|
||||
"must be greater than or equal to s_k")
|
||||
.insert("d", "128", "head dim for q, k")
|
||||
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
|
||||
.insert("scale_s",
|
||||
"0",
|
||||
"scale factor of S. 0 means equal to 1/sqrt(hdim).\n"
|
||||
"note when squant=1, this value will be modified")
|
||||
.insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)")
|
||||
.insert("qscale",
|
||||
"n",
|
||||
"n or 0, no scale\n"
|
||||
"pt or 1, per-tensor scale\n")
|
||||
.insert("logits_soft_cap", "0", "attention logits soft capping value.")
|
||||
.insert("squant",
|
||||
"auto",
|
||||
"if using static quantization fusion or not. auto: fp8 will default use squant, "
|
||||
"other will not\n"
|
||||
"0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to "
|
||||
"P and O.\n"
|
||||
"calculate scale_s, scale_p, scale_o auto")
|
||||
.insert("iperm",
|
||||
"1",
|
||||
"permute input\n"
|
||||
@@ -87,7 +81,8 @@ auto create_args(int argc, char* argv[])
|
||||
"uf",
|
||||
"init method:\n ui or 0 - uniform random int\n ni - normalized random int"
|
||||
"\n uf or 1 - uniform random float\n nf - normalized random float"
|
||||
"\n tf or 2 - trig float\n")
|
||||
"\n tf or 2 - trig float"
|
||||
"\n tf or 3 - uniform random float, min max is the max of the type\n")
|
||||
.insert("seed",
|
||||
"11939",
|
||||
"random seed used for initializing input tensors. 0 for "
|
||||
@@ -152,6 +147,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size");
|
||||
bool use_cache_batch_idx = arg_parser.get_bool("cache_batch_idx");
|
||||
std::string bias_str = arg_parser.get_str("bias");
|
||||
std::string qscale_str = arg_parser.get_str("qscale");
|
||||
float p_drop = arg_parser.get_float("p_drop");
|
||||
uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
|
||||
uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
|
||||
@@ -162,13 +158,6 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
std::string init_method = arg_parser.get_str("init");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
|
||||
bool squant = [&]() {
|
||||
if(arg_parser.get_str("squant") == "auto")
|
||||
return std::is_same_v<DataTypeConfig, FmhaFwdFp8>;
|
||||
else
|
||||
return arg_parser.get_bool("squant");
|
||||
}();
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
true,
|
||||
/* log_level = */ (arg_parser.get_bool("kname") ? 1 : 0),
|
||||
@@ -208,7 +197,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
drop_offset,
|
||||
drop_prefs,
|
||||
mask_str,
|
||||
squant,
|
||||
qscale_str,
|
||||
is_rotary_interleaved,
|
||||
num_splits,
|
||||
init_method,
|
||||
@@ -239,10 +228,6 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
return run<FmhaFwdBf16>(arg_parser) == fwd_result::success ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run<FmhaFwdFp8>(arg_parser) == fwd_result::success ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "fp8bf16")
|
||||
{
|
||||
return run<FmhaFwdFp8Bf16>(arg_parser) == fwd_result::success ? 0 : -2;
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
|
||||
#include "bias.hpp"
|
||||
#include "mask.hpp"
|
||||
#include "quant.hpp"
|
||||
#include "rotary.hpp"
|
||||
|
||||
#include <type_traits>
|
||||
@@ -178,6 +179,9 @@ struct fmha_fwd_args
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
const void* bias_ptr; // bias or alibi_slope pointer
|
||||
const void* q_descale_ptr;
|
||||
const void* k_descale_ptr;
|
||||
const void* v_descale_ptr;
|
||||
void* rand_val_ptr;
|
||||
void* lse_ptr;
|
||||
void* o_ptr;
|
||||
@@ -237,9 +241,6 @@ struct fmha_fwd_args
|
||||
ck_tile::index_t nhead_k;
|
||||
|
||||
float scale_s;
|
||||
float scale_p;
|
||||
float scale_o;
|
||||
|
||||
float logits_soft_cap;
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
@@ -581,6 +582,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.q_descale_ptr,
|
||||
args.k_descale_ptr,
|
||||
args.v_descale_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
@@ -593,8 +597,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.logits_soft_cap,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
@@ -625,6 +627,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.q_descale_ptr,
|
||||
args.k_descale_ptr,
|
||||
args.v_descale_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
@@ -635,8 +640,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.logits_soft_cap,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
@@ -1125,7 +1128,7 @@ template <ck_tile::index_t HDim_,
|
||||
ck_tile::BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kStoreLse_,
|
||||
bool kHasDropout_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
ck_tile::BlockAttentionQuantScaleEnum QScaleEnum_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
@@ -1150,7 +1153,7 @@ struct fmha_fwd_traits_
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kStoreLse = kStoreLse_;
|
||||
static constexpr bool kHasDropout = kHasDropout_;
|
||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||
static constexpr auto QScaleEnum = QScaleEnum_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadSK = kPadSK_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
@@ -1341,7 +1344,7 @@ struct fmha_fwd_traits
|
||||
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
|
||||
bool has_lse;
|
||||
bool has_dropout;
|
||||
bool do_fp8_static_quant;
|
||||
quant_scale_enum qscale_type;
|
||||
bool skip_min_seqlen_q = false;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
|
||||
@@ -178,7 +178,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
uint64_t drop_offset,
|
||||
bool drop_prefs,
|
||||
std::string mask_str,
|
||||
bool squant,
|
||||
std::string qscale_str,
|
||||
bool is_rotary_interleaved,
|
||||
ck_tile::index_t num_splits,
|
||||
std::string init_method,
|
||||
@@ -380,6 +380,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
mask_info mask =
|
||||
mask_info::decode(mask_str, seqlen_qs[0], seqlen_ks[0]); // TODO: we don't need x/y anymore
|
||||
|
||||
quant_scale_info qscale = quant_scale_info::decode(qscale_str);
|
||||
|
||||
if(p_drop < 0.0f || p_drop > 1.0f)
|
||||
{
|
||||
std::cerr << "The value of p_drop should be 0~1" << std::endl;
|
||||
@@ -572,6 +574,11 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
hdim_v}
|
||||
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
|
||||
|
||||
// TODO - change the tensor length for different quant scale
|
||||
ck_tile::HostTensor<float> q_descale_host(get_lengths(i_perm, 1, 1, 1, 1));
|
||||
ck_tile::HostTensor<float> k_descale_host(get_lengths(i_perm, 1, 1, 1, 1));
|
||||
ck_tile::HostTensor<float> v_descale_host(get_lengths(i_perm, 1, 1, 1, 1));
|
||||
|
||||
// batch mode of lse data layout is [batch, nhead, seqlen_q]
|
||||
// group mode of lse data layout is [nhead, total_seqlen_q]
|
||||
ck_tile::HostTensor<LSEDataType> lse_host(
|
||||
@@ -592,7 +599,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::HostTensor<int32_t> cache_batch_idx_host(use_cache_batch_idx
|
||||
? std::array<ck_tile::index_t, 1>{batch}
|
||||
: std::array<ck_tile::index_t, 1>{1});
|
||||
float max_o = 5.0;
|
||||
if(init_method == "ui" || init_method == "0")
|
||||
{
|
||||
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f, next_seed()}(q_host);
|
||||
@@ -640,6 +646,23 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::FillTrigValue<VDataType>{}(vnew_host);
|
||||
ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
|
||||
}
|
||||
else if(init_method == "3")
|
||||
{
|
||||
float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
|
||||
float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
|
||||
float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
|
||||
float bias_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<BiasDataType>::max());
|
||||
|
||||
ck_tile::FillUniformDistribution<QDataType>{-q_dtype_max, q_dtype_max, next_seed()}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{-k_dtype_max, k_dtype_max, next_seed()}(k_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{-k_dtype_max, k_dtype_max, next_seed()}(
|
||||
knew_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, next_seed()}(v_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, next_seed()}(
|
||||
vnew_host);
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{
|
||||
-bias_dtype_max, bias_dtype_max, next_seed()}(bias_host);
|
||||
}
|
||||
if(bias.type == bias_enum::alibi)
|
||||
{
|
||||
auto slopes = ck_tile::get_alibi_slopes<SaccDataType>(nhead);
|
||||
@@ -658,6 +681,18 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
}
|
||||
}
|
||||
}
|
||||
if(qscale.type == quant_scale_enum::pertensor)
|
||||
{
|
||||
float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
|
||||
float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
|
||||
float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
|
||||
|
||||
float qkv_max = 3.f;
|
||||
q_descale_host(0) = qkv_max / q_dtype_max;
|
||||
k_descale_host(0) = qkv_max / k_dtype_max;
|
||||
v_descale_host(0) = qkv_max / v_dtype_max;
|
||||
}
|
||||
|
||||
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine);
|
||||
iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine);
|
||||
|
||||
@@ -667,6 +702,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem q_descale_buf(q_descale_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_descale_buf(k_descale_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_descale_buf(v_descale_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
|
||||
@@ -702,81 +740,15 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes());
|
||||
|
||||
float scale_p = 1.f;
|
||||
float scale_o = 1.f;
|
||||
if(squant)
|
||||
{
|
||||
float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
|
||||
float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
|
||||
float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
|
||||
float p_dtype_max = v_dtype_max; // assume p and v is the same type
|
||||
// Q tensor
|
||||
{
|
||||
float max_value = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::min());
|
||||
q_host.ForEach([&](auto& self, auto idx) {
|
||||
float val = ck_tile::type_convert<float>(self(idx));
|
||||
if(val > max_value)
|
||||
max_value = val;
|
||||
});
|
||||
|
||||
float scale = q_dtype_max / max_value;
|
||||
|
||||
q_host.ForEach([&](auto& self, auto idx) {
|
||||
float val = ck_tile::type_convert<float>(self(idx));
|
||||
self(idx) = ck_tile::type_convert<QDataType>(val * scale);
|
||||
});
|
||||
scale_s = scale_s / scale;
|
||||
}
|
||||
|
||||
// K tensor
|
||||
{
|
||||
float max_value = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::min());
|
||||
k_host.ForEach([&](auto& self, auto idx) {
|
||||
float val = ck_tile::type_convert<float>(self(idx));
|
||||
if(val > max_value)
|
||||
max_value = val;
|
||||
});
|
||||
float scale = k_dtype_max / max_value;
|
||||
k_host.ForEach([&](auto& self, auto idx) {
|
||||
float val = ck_tile::type_convert<float>(self(idx));
|
||||
self(idx) = ck_tile::type_convert<KDataType>(val * scale);
|
||||
});
|
||||
scale_s = scale_s / scale;
|
||||
}
|
||||
|
||||
// V tensor
|
||||
{
|
||||
float max_value = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::min());
|
||||
v_host.ForEach([&](auto& self, auto idx) {
|
||||
float val = ck_tile::type_convert<float>(self(idx));
|
||||
if(val > max_value)
|
||||
max_value = val;
|
||||
});
|
||||
|
||||
float scale = k_dtype_max / max_value;
|
||||
v_host.ForEach([&](auto& self, auto idx) {
|
||||
float val = ck_tile::type_convert<float>(self(idx));
|
||||
self(idx) = ck_tile::type_convert<VDataType>(val * scale);
|
||||
});
|
||||
|
||||
scale_o = (1.0 / p_dtype_max) / scale;
|
||||
}
|
||||
|
||||
scale_p = p_dtype_max;
|
||||
|
||||
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8>)
|
||||
{
|
||||
float o_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<ODataType>::max());
|
||||
scale_o = scale_o * o_dtype_max / max_o;
|
||||
}
|
||||
}
|
||||
|
||||
q_buf.ToDevice(q_host.data());
|
||||
k_buf.ToDevice(k_host.data());
|
||||
v_buf.ToDevice(v_host.data());
|
||||
knew_buf.ToDevice(knew_host.data());
|
||||
vnew_buf.ToDevice(vnew_host.data());
|
||||
bias_buf.ToDevice(bias_host.data());
|
||||
q_descale_buf.ToDevice(q_descale_host.data());
|
||||
k_descale_buf.ToDevice(k_descale_host.data());
|
||||
v_descale_buf.ToDevice(v_descale_host.data());
|
||||
seqstart_q.ToDevice(seqstart_q_host.data());
|
||||
// Keep logical starts in seqstart_k; pass padded K via separate pointer
|
||||
seqstart_k.ToDevice(seqstart_k_host.data());
|
||||
@@ -816,7 +788,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
<< (seqlen_kpads[0] < 0 ? ""
|
||||
: (std::string("(") + std::to_string(seqlen_kpads[0]) + ")"))
|
||||
<< ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias
|
||||
<< ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant
|
||||
<< ", p_drop:" << p_drop << ", lse:" << lse << ", qscale:" << qscale
|
||||
<< ", mask:" << mask << ", v:" << (is_v_rowmajor ? "r" : "c");
|
||||
#if CK_TILE_FMHA_FWD_APPENDKV_API
|
||||
if(0 < rotary_dim)
|
||||
@@ -908,11 +880,11 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
traits.mask_type = mask.type;
|
||||
traits.bias_type = bias.type;
|
||||
traits.has_lse = lse;
|
||||
traits.do_fp8_static_quant = squant;
|
||||
|
||||
if constexpr(std::is_same_v<fmha_fwd_traits, std::decay_t<decltype(traits)>>)
|
||||
{
|
||||
traits.has_dropout = (p_drop > 0.0f);
|
||||
traits.qscale_type = qscale.type;
|
||||
}
|
||||
else if constexpr(std::is_same_v<fmha_fwd_pagedkv_traits,
|
||||
std::decay_t<decltype(traits)>>)
|
||||
@@ -1055,8 +1027,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
args.max_seqlen_q = max_seqlen_q;
|
||||
|
||||
args.scale_s = scale_s;
|
||||
args.scale_p = scale_p;
|
||||
args.scale_o = scale_o;
|
||||
|
||||
args.logits_soft_cap = logits_soft_cap;
|
||||
|
||||
@@ -1076,6 +1046,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
|
||||
{
|
||||
args.q_descale_ptr = q_descale_buf.GetDeviceBuffer();
|
||||
args.k_descale_ptr = k_descale_buf.GetDeviceBuffer();
|
||||
args.v_descale_ptr = v_descale_buf.GetDeviceBuffer();
|
||||
|
||||
args.rand_val_ptr = randval_buf.GetDeviceBuffer();
|
||||
|
||||
args.stride_randval = stride_randval;
|
||||
@@ -1351,23 +1325,34 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
lse_buf.FromDevice(lse_host.data());
|
||||
randval_buf.FromDevice(randval_host.data());
|
||||
|
||||
constexpr bool supports_squant = std::is_same_v<DataTypeConfig, FmhaFwdFp8> ||
|
||||
constexpr bool supports_qscale = std::is_same_v<DataTypeConfig, FmhaFwdFp8> ||
|
||||
std::is_same_v<DataTypeConfig, FmhaFwdFp8Bf16> ||
|
||||
std::is_same_v<DataTypeConfig, FmhaFwdFp8Fp32>;
|
||||
|
||||
float scale_s_host = scale_s;
|
||||
float scale_p_host = 1.0f;
|
||||
float scale_o_host = 1.0f;
|
||||
|
||||
if(qscale.type == quant_scale_enum::pertensor)
|
||||
{
|
||||
scale_s_host = scale_s * q_descale_host(0) * k_descale_host(0);
|
||||
scale_p_host = ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
|
||||
scale_o_host = v_descale_host(0) / scale_p_host;
|
||||
}
|
||||
|
||||
auto p_compute_element_func = [&]() {
|
||||
if constexpr(supports_squant)
|
||||
return ck_tile::scales{scale_p};
|
||||
if constexpr(supports_qscale)
|
||||
return ck_tile::scales{scale_p_host};
|
||||
else
|
||||
return ck_tile::identity{};
|
||||
}();
|
||||
|
||||
auto oacc_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t> && supports_squant)
|
||||
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t> && supports_qscale)
|
||||
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
|
||||
ck_tile::scales{scale_o});
|
||||
else if constexpr(supports_squant)
|
||||
return ck_tile::scales{scale_o};
|
||||
ck_tile::scales{scale_o_host});
|
||||
else if constexpr(supports_qscale)
|
||||
return ck_tile::scales{scale_o_host};
|
||||
else
|
||||
return ck_tile::identity{};
|
||||
}();
|
||||
@@ -1573,7 +1558,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
s_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale_s));
|
||||
ck_tile::scales(scale_s_host));
|
||||
|
||||
if(0.f < logits_soft_cap)
|
||||
{
|
||||
@@ -1818,7 +1803,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
scale_s,
|
||||
p_drop,
|
||||
lse,
|
||||
squant,
|
||||
qscale.type == quant_scale_enum::no_scale ? "no_scale"
|
||||
: "pertensor",
|
||||
bias.type == bias_enum::elementwise_bias
|
||||
? "elementwise_bias"
|
||||
: (bias.type == bias_enum::alibi ? "alibi" : "no_bias"),
|
||||
|
||||
53
example/ck_tile/01_fmha/quant.hpp
Normal file
53
example/ck_tile/01_fmha/quant.hpp
Normal file
@@ -0,0 +1,53 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha.hpp"
|
||||
|
||||
// keep sync with BlockAttentionQuantScaleEnum
|
||||
enum class quant_scale_enum
|
||||
{
|
||||
no_scale = 0,
|
||||
pertensor = 1,
|
||||
};
|
||||
|
||||
struct quant_scale_info
|
||||
{
|
||||
quant_scale_enum type;
|
||||
|
||||
void serialize(std::ostream& os) const
|
||||
{
|
||||
if(type == quant_scale_enum::no_scale)
|
||||
os << "n";
|
||||
else if(type == quant_scale_enum::pertensor)
|
||||
os << "pt";
|
||||
}
|
||||
|
||||
static quant_scale_info decode(std::string str)
|
||||
{
|
||||
quant_scale_info info{quant_scale_enum::no_scale};
|
||||
if(str == "n" || str == "0")
|
||||
{
|
||||
info.type = quant_scale_enum::no_scale;
|
||||
}
|
||||
else if(str == "pt" || str == "1")
|
||||
{
|
||||
info.type = quant_scale_enum::pertensor;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::invalid_argument("invalid quant scale value: " + str);
|
||||
}
|
||||
return info;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const quant_scale_info& qsi)
|
||||
{
|
||||
qsi.serialize(os);
|
||||
return os;
|
||||
}
|
||||
};
|
||||
@@ -73,52 +73,39 @@ run_fp16_bf16_tests() {
|
||||
for page_block_size in $PAGE_BLOCK_SIZE ; do
|
||||
for cache_batch_idx in $CACHE_BATCH_IDX ; do
|
||||
|
||||
# run_exe -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
# run_exe -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done ; done ; done ; done ; done
|
||||
done ; done ; done ; done ; done
|
||||
}
|
||||
|
||||
run_fp8_tests() {
|
||||
for perm in 0 1 ; do
|
||||
for bias in "n" "e" "a" ; do
|
||||
for b in 1 2 ; do
|
||||
for hdim in 64 128 256 ; do
|
||||
|
||||
$EXE -prec=fp8 -init=0 -b=$b -h=1 -d=$hdim -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done ; done ; done ; done
|
||||
}
|
||||
|
||||
run_fp8bf16_tests() {
|
||||
for perm in 0 1 ; do
|
||||
for bias in "n" "e" "a" ; do
|
||||
for b in 1 2 ; do
|
||||
for hdim in 64 128 256 ; do
|
||||
|
||||
$EXE -prec=fp8bf16 -init=0 -b=$b -h=1 -d=$hdim -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=fp8bf16 -init=3 -b=$b -h=1 -d=$hdim -s=128 -iperm=$perm -operm=$perm -vlayout=r -qscale=1 -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done ; done ; done ; done
|
||||
done ; done ; done
|
||||
}
|
||||
|
||||
run_fp8fp32_tests() {
|
||||
for perm in 0 1 ; do
|
||||
for bias in "n" "e" "a" ; do
|
||||
for b in 1 2 ; do
|
||||
for hdim in 128 ; do
|
||||
|
||||
$EXE -prec=fp8fp32 -init=0 -b=$b -h=1 -d=$hdim -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=fp8fp32 -init=3 -b=$b -h=1 -d=$hdim -s=128 -iperm=$perm -operm=$perm -vlayout=r -qscale=1 -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done ; done ; done ; done
|
||||
done ; done ; done
|
||||
}
|
||||
|
||||
run_fp16_appendkv_tests() {
|
||||
@@ -133,7 +120,7 @@ run_fp16_appendkv_tests() {
|
||||
|
||||
run_exe -prec=fp16 -b=3 -h=3 -d=$hdim -s=$s -s_k=$s_k -s_knew=$s_knew -rotary_dim=$rdim -rotary_interleaved=$ri -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -iperm=1 -operm=1 -kname=1 $COMMON_ARGS
|
||||
|
||||
done ; done ; done ; done ; done
|
||||
done ; done ; done ; done ; done
|
||||
done ; done ; done
|
||||
}
|
||||
|
||||
@@ -249,7 +236,6 @@ set -x
|
||||
run_fp16_bf16_tests
|
||||
run_padding_smoke_tests
|
||||
run_padding_basic_boundary_tests
|
||||
run_fp8_tests
|
||||
run_fp8bf16_tests
|
||||
run_fp8fp32_tests
|
||||
|
||||
|
||||
@@ -598,6 +598,8 @@ struct HostTensor
|
||||
|
||||
typename Data::size_type size() const { return mData.size(); }
|
||||
|
||||
T max() const { return *std::max_element(mData.begin(), mData.end()); }
|
||||
|
||||
// return a slice of this tensor
|
||||
// for simplicity we just copy the data and return a new tensor
|
||||
auto slice(std::vector<size_t> s_begin, std::vector<size_t> s_end) const
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_masking.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// This class is used for codegen pattern matching
|
||||
enum class BlockAttentionQuantScaleEnum
|
||||
{
|
||||
NO_SCALE = 0,
|
||||
PERTENSOR = 1,
|
||||
};
|
||||
|
||||
template <BlockAttentionQuantScaleEnum>
|
||||
struct BlockAttentionQuantScaleEnumToStr;
|
||||
|
||||
template <>
|
||||
struct BlockAttentionQuantScaleEnumToStr<BlockAttentionQuantScaleEnum::NO_SCALE>
|
||||
{
|
||||
static constexpr const char* name = "";
|
||||
};
|
||||
template <>
|
||||
struct BlockAttentionQuantScaleEnumToStr<BlockAttentionQuantScaleEnum::PERTENSOR>
|
||||
{
|
||||
static constexpr const char* name = "pertensor";
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
|
||||
#include <string>
|
||||
@@ -36,6 +37,7 @@ struct FmhaFwdKernel
|
||||
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
|
||||
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
|
||||
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
|
||||
using PDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::PDataType>;
|
||||
using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
|
||||
using RandValOutputDataType =
|
||||
ck_tile::remove_cvref_t<typename FmhaPipeline::RandValOutputDataType>;
|
||||
@@ -54,7 +56,7 @@ struct FmhaFwdKernel
|
||||
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
|
||||
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
|
||||
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
|
||||
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
|
||||
static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;
|
||||
static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
|
||||
|
||||
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
|
||||
@@ -112,7 +114,8 @@ struct FmhaFwdKernel
|
||||
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
|
||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
|
||||
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
||||
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kUseTrLoad ? "_trload" : "_ntrload");
|
||||
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) +
|
||||
(QScaleEnum == BlockAttentionQuantScaleEnum::NO_SCALE ? _SS_("_nqscale") : (_SS_("_") + BlockAttentionQuantScaleEnumToStr<QScaleEnum>::name)) + (kUseTrLoad ? "_trload" : "_ntrload");
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
// clang-format on
|
||||
@@ -204,10 +207,11 @@ struct FmhaFwdKernel
|
||||
ck_tile::GenericAttentionMaskEnum mask_type;
|
||||
};
|
||||
|
||||
struct FmhaFwdFp8StaticQuantKargs
|
||||
struct FmhaFwdCommonQScaleKargs
|
||||
{
|
||||
float scale_p;
|
||||
float scale_o;
|
||||
const void* q_descale_ptr = nullptr;
|
||||
const void* k_descale_ptr = nullptr;
|
||||
const void* v_descale_ptr = nullptr;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonLSEKargs
|
||||
@@ -285,7 +289,9 @@ struct FmhaFwdKernel
|
||||
FmhaFwdEmptyKargs<0>>>,
|
||||
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
FmhaFwdEmptyKargs<3>>,
|
||||
std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
|
||||
{
|
||||
@@ -309,7 +315,9 @@ struct FmhaFwdKernel
|
||||
FmhaFwdEmptyKargs<0>>>,
|
||||
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
FmhaFwdEmptyKargs<3>>,
|
||||
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
|
||||
std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
|
||||
@@ -339,6 +347,9 @@ struct FmhaFwdKernel
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
const void* q_descale_ptr,
|
||||
const void* k_descale_ptr,
|
||||
const void* v_descale_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
@@ -349,8 +360,6 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
float logits_soft_cap,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
@@ -408,7 +417,7 @@ struct FmhaFwdKernel
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for lse
|
||||
{}, // placeholder for fp8_static_quant args
|
||||
{}, // placeholder for qscale
|
||||
{}, // placeholder for dropout
|
||||
{}, // placeholder for logits_soft_cap
|
||||
batch_stride_q,
|
||||
@@ -440,10 +449,11 @@ struct FmhaFwdKernel
|
||||
kargs.nhead_stride_lse = nhead_stride_lse;
|
||||
kargs.batch_stride_lse = batch_stride_lse;
|
||||
}
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
|
||||
{
|
||||
kargs.scale_p = scale_p;
|
||||
kargs.scale_o = scale_o;
|
||||
kargs.q_descale_ptr = q_descale_ptr;
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
@@ -483,6 +493,9 @@ struct FmhaFwdKernel
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
const void* q_descale_ptr,
|
||||
const void* k_descale_ptr,
|
||||
const void* v_descale_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
@@ -493,8 +506,6 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
float logits_soft_cap,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
@@ -530,6 +541,9 @@ struct FmhaFwdKernel
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
q_descale_ptr,
|
||||
k_descale_ptr,
|
||||
v_descale_ptr,
|
||||
rand_val_ptr,
|
||||
lse_ptr,
|
||||
o_ptr,
|
||||
@@ -540,8 +554,6 @@ struct FmhaFwdKernel
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
scale_s,
|
||||
scale_p,
|
||||
scale_o,
|
||||
logits_soft_cap,
|
||||
stride_q,
|
||||
stride_k,
|
||||
@@ -580,6 +592,9 @@ struct FmhaFwdKernel
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
const void* q_descale_ptr,
|
||||
const void* k_descale_ptr,
|
||||
const void* v_descale_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
@@ -590,8 +605,6 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
float logits_soft_cap,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
@@ -627,6 +640,9 @@ struct FmhaFwdKernel
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
q_descale_ptr,
|
||||
k_descale_ptr,
|
||||
v_descale_ptr,
|
||||
rand_val_ptr,
|
||||
lse_ptr,
|
||||
o_ptr,
|
||||
@@ -637,8 +653,6 @@ struct FmhaFwdKernel
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
scale_s,
|
||||
scale_p,
|
||||
scale_o,
|
||||
logits_soft_cap,
|
||||
stride_q,
|
||||
stride_k,
|
||||
@@ -676,6 +690,9 @@ struct FmhaFwdKernel
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
const void* q_descale_ptr,
|
||||
const void* k_descale_ptr,
|
||||
const void* v_descale_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
@@ -688,8 +705,6 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
float logits_soft_cap,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
@@ -741,7 +756,7 @@ struct FmhaFwdKernel
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for lse
|
||||
{}, // placeholder for fp8_static_quant args
|
||||
{}, // placeholder for qscale
|
||||
{}, // placeholder for dropout
|
||||
{}, // placeholder for logits_soft_cap
|
||||
{}, // placeholder for min_seqlen_q
|
||||
@@ -772,10 +787,11 @@ struct FmhaFwdKernel
|
||||
kargs.lse_ptr = lse_ptr;
|
||||
kargs.nhead_stride_lse = nhead_stride_lse;
|
||||
}
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
|
||||
{
|
||||
kargs.scale_p = scale_p;
|
||||
kargs.scale_o = scale_o;
|
||||
kargs.q_descale_ptr = q_descale_ptr;
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
@@ -818,6 +834,9 @@ struct FmhaFwdKernel
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
const void* q_descale_ptr,
|
||||
const void* k_descale_ptr,
|
||||
const void* v_descale_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
@@ -830,8 +849,6 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
float logits_soft_cap,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
@@ -861,6 +878,9 @@ struct FmhaFwdKernel
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
q_descale_ptr,
|
||||
k_descale_ptr,
|
||||
v_descale_ptr,
|
||||
rand_val_ptr,
|
||||
lse_ptr,
|
||||
o_ptr,
|
||||
@@ -873,8 +893,6 @@ struct FmhaFwdKernel
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
scale_s,
|
||||
scale_p,
|
||||
scale_o,
|
||||
logits_soft_cap,
|
||||
stride_q,
|
||||
stride_k,
|
||||
@@ -907,6 +925,9 @@ struct FmhaFwdKernel
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
const void* q_descale_ptr,
|
||||
const void* k_descale_ptr,
|
||||
const void* v_descale_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
@@ -919,8 +940,6 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
float logits_soft_cap,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
@@ -950,6 +969,9 @@ struct FmhaFwdKernel
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
q_descale_ptr,
|
||||
k_descale_ptr,
|
||||
v_descale_ptr,
|
||||
rand_val_ptr,
|
||||
lse_ptr,
|
||||
o_ptr,
|
||||
@@ -962,8 +984,6 @@ struct FmhaFwdKernel
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
scale_s,
|
||||
scale_p,
|
||||
scale_o,
|
||||
logits_soft_cap,
|
||||
stride_q,
|
||||
stride_k,
|
||||
@@ -1527,14 +1547,24 @@ struct FmhaFwdKernel
|
||||
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
|
||||
|
||||
auto o_acc_tile = [&]() {
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
|
||||
{
|
||||
// TODO - move global load of descale to pipeline
|
||||
float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
|
||||
float k_descale = *(reinterpret_cast<const float*>(kargs.k_descale_ptr));
|
||||
float v_descale = *(reinterpret_cast<const float*>(kargs.v_descale_ptr));
|
||||
|
||||
float scale_s = kargs.scale_s * q_descale * k_descale;
|
||||
float scale_p =
|
||||
ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
|
||||
float scale_o = v_descale / scale_p;
|
||||
|
||||
auto o_acc_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
|
||||
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
|
||||
ck_tile::scales{kargs.scale_o});
|
||||
ck_tile::scales{scale_o});
|
||||
else
|
||||
return ck_tile::scales{kargs.scale_o};
|
||||
return ck_tile::scales{scale_o};
|
||||
}();
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
@@ -1546,13 +1576,13 @@ struct FmhaFwdKernel
|
||||
identity{}, // bias_element_func
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
scales{kargs.scale_p}, // p_compute_element_func
|
||||
o_acc_element_func, // o_acc_element_func
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
scales{scale_p}, // p_compute_element_func
|
||||
o_acc_element_func, // o_acc_element_func
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
|
||||
@@ -60,7 +60,7 @@ struct BlockFmhaPipelineProblem
|
||||
static constexpr auto BiasEnum = Traits::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Traits::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Traits::kHasDropout;
|
||||
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
|
||||
static constexpr auto QScaleEnum = Traits::QScaleEnum;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
};
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -18,7 +19,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kHasBiasGrad_,
|
||||
bool kStoreLSE_,
|
||||
bool kHasDropout_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
BlockAttentionQuantScaleEnum QScaleEnum_,
|
||||
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
|
||||
bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
|
||||
struct TileFmhaTraits
|
||||
@@ -32,7 +33,7 @@ struct TileFmhaTraits
|
||||
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
|
||||
static constexpr bool kStoreLSE = kStoreLSE_;
|
||||
static constexpr bool kHasDropout = kHasDropout_;
|
||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||
static constexpr auto QScaleEnum = QScaleEnum_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||
};
|
||||
|
||||
@@ -625,7 +625,7 @@ void dump_fmha_fwd_json_results(const std::string& json_filename,
|
||||
float scale_s,
|
||||
float p_drop,
|
||||
bool lse,
|
||||
bool squant,
|
||||
const std::string& qscale,
|
||||
const std::string& bias,
|
||||
const std::string& vlayout,
|
||||
bool pass,
|
||||
@@ -650,7 +650,7 @@ void dump_fmha_fwd_json_results(const std::string& json_filename,
|
||||
ADD_KEY_VALUE("scale_s", scale_s);
|
||||
ADD_KEY_VALUE("p_drop", p_drop);
|
||||
ADD_KEY_VALUE("lse", lse);
|
||||
ADD_KEY_VALUE("squant", squant);
|
||||
ADD_KEY_VALUE("qscale", qscale);
|
||||
ADD_KEY_VALUE("bias", bias);
|
||||
ADD_KEY_VALUE("vlayout", vlayout);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
|
||||
@@ -9,10 +9,10 @@ set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances")
|
||||
set(TEST_NAME "test_ck_tile_fmha")
|
||||
|
||||
function(add_gtest_fwd test_group)
|
||||
set(V_TYPES "fp16" "bf16" "fp8" "fp32")
|
||||
set(V_TYPES "fp16" "bf16" "fp8bf16" "fp32")
|
||||
set(CPP_TYPE_fp16 "FmhaFwdFp16")
|
||||
set(CPP_TYPE_bf16 "FmhaFwdBf16")
|
||||
set(CPP_TYPE_fp8 "FmhaFwdFp8")
|
||||
set(CPP_TYPE_fp8bf16 "FmhaFwdFp8Bf16")
|
||||
set(CPP_TYPE_fp32 "FmhaFwdFp32")
|
||||
|
||||
set(all_tests)
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#ifndef DataTypeConfig
|
||||
#define DataTypeConfig FmhaFwdFp16 // or FmhaFwdBf16 / FmhaFwdFp8 / FmhaFwdFp32
|
||||
#define DataTypeConfig FmhaFwdFp16 // or FmhaFwdBf16 / FmhaFwdFp8Bf16 / FmhaFwdFp32
|
||||
#endif
|
||||
|
||||
using ::testing::Bool;
|
||||
@@ -39,13 +39,14 @@ struct TestConfigs
|
||||
std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}};
|
||||
static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group};
|
||||
static constexpr auto IsVRowmajorValues = std::array{true};
|
||||
static constexpr bool squant = false;
|
||||
static constexpr auto qscale_str = "n";
|
||||
static constexpr bool def_lse = true;
|
||||
static constexpr bool def_is_v_rowmajor = true;
|
||||
static int adjust_seqlen(int seqlen) { return seqlen; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TestConfigs<FmhaFwdFp8>
|
||||
struct TestConfigs<FmhaFwdFp8Bf16>
|
||||
{
|
||||
static constexpr auto HDimValues =
|
||||
std::array{std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}};
|
||||
@@ -53,13 +54,14 @@ struct TestConfigs<FmhaFwdFp8>
|
||||
static constexpr auto AppendKVHDimValues = std::array{std::tuple{64, -1}, std::tuple{128, -1}};
|
||||
static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group};
|
||||
static constexpr auto IsVRowmajorValues = std::array{true};
|
||||
static constexpr bool squant = true;
|
||||
static constexpr auto qscale_str = "pt";
|
||||
static constexpr bool def_lse = false;
|
||||
static constexpr bool def_is_v_rowmajor = true;
|
||||
// When there are no fp8 instances with padding, pad seqlen to avoid skipping most of the tests:
|
||||
// return ck_tile::integer_least_multiple(seqlen, 128);
|
||||
static int adjust_seqlen(int seqlen) { return seqlen; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TestConfigs<FmhaFwdFp32>
|
||||
{
|
||||
@@ -76,7 +78,7 @@ struct TestConfigs<FmhaFwdFp32>
|
||||
static constexpr auto AppendKVHDimValues = std::array<std::tuple<int, int>, 0>{};
|
||||
static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group};
|
||||
static constexpr auto IsVRowmajorValues = std::array{true};
|
||||
static constexpr bool squant = false;
|
||||
static constexpr auto qscale_str = "n";
|
||||
static constexpr bool def_lse = true;
|
||||
static constexpr bool def_is_v_rowmajor = true;
|
||||
static int adjust_seqlen(int seqlen) { return seqlen; }
|
||||
@@ -87,7 +89,7 @@ static auto SplitKVHDimValues = ValuesIn(TestConfigs<DataTypeConfig>::SplitKV
|
||||
static auto AppendKVHDimValues = ValuesIn(TestConfigs<DataTypeConfig>::AppendKVHDimValues);
|
||||
static auto ModeValues = ValuesIn(TestConfigs<DataTypeConfig>::ModeValues);
|
||||
static auto IsVRowmajorValues = ValuesIn(TestConfigs<DataTypeConfig>::IsVRowmajorValues);
|
||||
constexpr bool squant = TestConfigs<DataTypeConfig>::squant;
|
||||
constexpr static auto qscale_str = TestConfigs<DataTypeConfig>::qscale_str;
|
||||
constexpr bool def_lse = TestConfigs<DataTypeConfig>::def_lse;
|
||||
constexpr bool def_is_v_rowmajor = TestConfigs<DataTypeConfig>::def_is_v_rowmajor;
|
||||
int adjust_seqlen(int seqlen) { return TestConfigs<DataTypeConfig>::adjust_seqlen(seqlen); }
|
||||
@@ -203,7 +205,7 @@ TEST_P(AllLong, DataTypeConfig)
|
||||
1024, // drop_offset
|
||||
false, // drop_prefs
|
||||
mask_str, // mask_str
|
||||
squant,
|
||||
qscale_str,
|
||||
true, // is_rotary_interleaved
|
||||
1, // num_splits
|
||||
COMMON_ARGS);
|
||||
@@ -247,7 +249,7 @@ TEST(TestCkTileFmhaFwd, AppendKvWithBatchEffLensShouldFail)
|
||||
0, // drop_offset
|
||||
false, // drop_prefs
|
||||
"0", // mask
|
||||
squant,
|
||||
qscale_str,
|
||||
true, // is_rotary_interleaved
|
||||
1, // num_splits
|
||||
init_method,
|
||||
@@ -291,7 +293,7 @@ TEST(TestCkTileFmhaFwd, SplitKvWithGroupPaddingShouldFail)
|
||||
0,
|
||||
false,
|
||||
"0",
|
||||
squant,
|
||||
qscale_str,
|
||||
true,
|
||||
2, // num_splits (>1 triggers splitkv)
|
||||
init_method,
|
||||
@@ -334,7 +336,7 @@ TEST(TestCkTileFmhaFwd, PagedKvWithGroupPaddingShouldFail)
|
||||
0,
|
||||
false,
|
||||
"0",
|
||||
squant,
|
||||
qscale_str,
|
||||
true,
|
||||
1,
|
||||
init_method,
|
||||
@@ -403,7 +405,7 @@ TEST_P(HDimPadding, DataTypeConfig)
|
||||
0, // drop_offset
|
||||
false, // drop_prefs
|
||||
mask_str, // mask_str
|
||||
squant,
|
||||
qscale_str,
|
||||
true, // is_rotary_interleaved
|
||||
1, // num_splits
|
||||
COMMON_ARGS);
|
||||
@@ -463,7 +465,7 @@ TEST_P(ElementwiseBias, DataTypeConfig)
|
||||
0, // drop_offset
|
||||
false, // drop_prefs
|
||||
mask_str, // mask_str
|
||||
squant,
|
||||
qscale_str,
|
||||
true, // is_rotary_interleaved
|
||||
1, // num_splits
|
||||
COMMON_ARGS);
|
||||
@@ -522,7 +524,7 @@ TEST_P(Alibi, DataTypeConfig)
|
||||
0, // drop_offset
|
||||
false, // drop_prefs
|
||||
mask_str, // mask_str
|
||||
squant,
|
||||
qscale_str,
|
||||
true, // is_rotary_interleaved
|
||||
1, // num_splits
|
||||
COMMON_ARGS);
|
||||
@@ -583,7 +585,7 @@ TEST_P(Dropout, DataTypeConfig)
|
||||
drop_offset, // drop_offset
|
||||
drop_prefs, // drop_prefs
|
||||
mask_str, // mask_str
|
||||
squant,
|
||||
qscale_str,
|
||||
true, // is_rotary_interleaved
|
||||
1, // num_splits
|
||||
COMMON_ARGS);
|
||||
@@ -648,7 +650,7 @@ TEST_P(PagedKV, DataTypeConfig)
|
||||
0, // drop_offset
|
||||
false, // drop_prefs
|
||||
mask_str, // mask_str
|
||||
squant,
|
||||
qscale_str,
|
||||
true, // is_rotary_interleaved
|
||||
1, // num_splits
|
||||
COMMON_ARGS);
|
||||
@@ -719,7 +721,7 @@ TEST_P(SplitKV, DataTypeConfig)
|
||||
0, // drop_offset
|
||||
false, // drop_prefs
|
||||
mask_str, // mask_str
|
||||
squant,
|
||||
qscale_str,
|
||||
true, // is_rotary_interleaved
|
||||
num_splits, // num_splits
|
||||
COMMON_ARGS);
|
||||
@@ -796,7 +798,7 @@ TEST_P(AppendKV, DataTypeConfig)
|
||||
0, // drop_offset
|
||||
false, // drop_prefs
|
||||
mask_str, // mask_str
|
||||
squant,
|
||||
qscale_str,
|
||||
false, // is_rotary_interleaved
|
||||
1, // num_splits
|
||||
COMMON_ARGS);
|
||||
@@ -818,7 +820,7 @@ GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AppendKVRoPE);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
|
||||
AppendKVRoPE,
|
||||
Combine(EnableTestIf(!std::is_same_v<DataTypeConfig, FmhaFwdFp8>),
|
||||
Combine(EnableTestIf(!std::is_same_v<DataTypeConfig, FmhaFwdFp8Bf16>),
|
||||
AppendKVHDimValues,
|
||||
Bool(), // layouts of k and v are controlled by i_perm
|
||||
IsVRowmajorValues, // layout of v is controlled by is_v_rowmajor
|
||||
@@ -869,7 +871,7 @@ TEST_P(AppendKVRoPE, DataTypeConfig)
|
||||
0, // drop_offset
|
||||
false, // drop_prefs
|
||||
mask_str, // mask_str
|
||||
squant,
|
||||
qscale_str,
|
||||
is_rotary_interleaved, // is_rotary_interleaved
|
||||
1, // num_splits
|
||||
COMMON_ARGS);
|
||||
@@ -1105,7 +1107,7 @@ INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd_Padding, PaddingCases, ValuesIn(kPadd
|
||||
|
||||
TEST_P(PaddingCases, DataTypeConfig)
|
||||
{
|
||||
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8>)
|
||||
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8Bf16>)
|
||||
{
|
||||
GTEST_SKIP() << "Skip for fp8";
|
||||
}
|
||||
@@ -1162,7 +1164,7 @@ TEST_P(PaddingCases, DataTypeConfig)
|
||||
0, // drop_offset
|
||||
false, // drop_prefs
|
||||
mask_str, // mask_str
|
||||
squant,
|
||||
qscale_str,
|
||||
true, // is_rotary_interleaved
|
||||
1, // num_splits
|
||||
COMMON_ARGS);
|
||||
|
||||
Reference in New Issue
Block a user