Support fp8 dynamic quantization for fmha (#3206)

* Support qscale for dynamic quant, remove static quant

* Support hdim=256

* Remove bias test case for fp8

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: asleepzzz <hanwen.chang@amd.com>
Author: rocking
Date: 2025-11-24 16:28:25 +08:00 (committed by GitHub)
Parent: 096f0a3b23
Commit: 5948dbffe4
17 changed files with 369 additions and 280 deletions

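The headline change: the forward kernels previously exposed a boolean `do_fp8_static_quant` trait; this commit replaces it with a `qscale_type` trait whose only concrete kind so far is per-tensor dynamic quantization. As background, a minimal sketch of how a per-tensor dynamic scale is typically computed at runtime (illustrative only; the function name and the e4m3 constant are assumptions, not part of this patch):

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite e4m3 value (assumed fp8 format)

def pertensor_qscale(x: np.ndarray) -> float:
    """Per-tensor *dynamic* quantization: the scale is derived from the
    runtime amax of the tensor, unlike the removed *static* path, which
    used a precomputed constant scale."""
    amax = float(np.abs(x).max())
    return amax / FP8_E4M3_MAX if amax > 0.0 else 1.0

# Usage sketch: scale q into fp8 range before the fmha call.
q = np.random.randn(4, 128).astype(np.float32)
q_fp8 = np.clip(q / pertensor_qscale(q), -FP8_E4M3_MAX, FP8_E4M3_MAX)
```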

@@ -24,6 +24,8 @@ from codegen.cpp_symbol_map import (
FWD_DTYPE_MAP,
BIAS_MAP,
get_mask_map,
+QSCALE_CHECK_MAP,
+QSCALE_MAP,
)
from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file
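The two new imports are defined in codegen/cpp_symbol_map.py, which is outside this diff. By analogy with the existing BOOL_MAP/BIAS_MAP lookups, a plausible shape (the keys follow the "no"/"pertensor" strings used below; the C++ tokens are guesses):

```python
# Hypothetical reconstruction -- the real tables live in
# codegen/cpp_symbol_map.py and are not shown in this commit.
QSCALE_MAP = {        # python key -> C++ template argument
    "no": "ck_tile::QScaleKind::kNone",
    "pertensor": "ck_tile::QScaleKind::kPerTensor",
}
QSCALE_CHECK_MAP = {  # python key -> runtime enum compared in dispatch
    "no": "quant_scale_enum::no",
    "pertensor": "quant_scale_enum::pertensor",
}
```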
@@ -64,7 +66,7 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
false,
{F_lse},
{F_dropout},
-{F_squant},
+{F_qscale},
{F_occupancy},
{F_skip}>;
@@ -103,7 +105,7 @@ using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
-{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
+{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
template<>
float fmha_fwd_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_args a)
@@ -190,9 +192,9 @@ FMHA_FWD_API_PER_HDIM_CASE = """{F_if}(t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hd
}}
"""
-FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
+FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) &&
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
-using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
+using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
return fmha_fwd_<trait_, {F_arch.tag}>(s, a);
}}
"""
@@ -232,7 +234,7 @@ class FmhaFwdApiTrait:
bias: str #
lse: str #
dropout: str
-squant: str #
+qscale: str #
spad: str
skpad: str
dpad: str
@@ -245,7 +247,7 @@ class FmhaFwdApiTrait:
def name(self) -> str:
return (
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}"
)
@property
@@ -341,7 +343,7 @@ class FmhaFwdPipeline:
F_bias: str # true/false
F_lse: str #
F_dropout: str #
-F_squant: str #
+F_qscale: str # no/pertensor
F_mask: str # value from MASK_MAP
F_skip: str # true/false
F_trload: str # true/false
@@ -406,10 +408,10 @@ class FmhaFwdPipeline:
else:
n += "_nskip"
if self.F_squant == "t":
n += "_squant"
if self.F_qscale != "no":
n += f"_{self.F_qscale}"
else:
n += "_nsquant"
n += "_nqscale"
if self.F_trload == "t":
n += "_trload"
@@ -462,7 +464,8 @@ class FmhaFwdApiPool:
F_dropout=BOOL_MAP[trait.dropout],
F_skip=BOOL_MAP[trait.skip],
F_trload=BOOL_MAP[trait.tr_load],
-F_squant=BOOL_MAP[trait.squant],
+F_qscale_check=QSCALE_CHECK_MAP[trait.qscale],
+F_qscale=QSCALE_MAP[trait.qscale],
F_scheck=trait.scheck,
F_seqtune=trait.seqtune(max_bm0),
F_skcheck=trait.skcheck,
@@ -580,7 +583,7 @@ class FmhaFwdKernel:
F_bias=BIAS_MAP[self.F_pipeline.F_bias],
F_lse=BOOL_MAP[self.F_pipeline.F_lse],
F_dropout=BOOL_MAP[self.F_pipeline.F_dropout],
-F_squant=BOOL_MAP[self.F_pipeline.F_squant],
+F_qscale=QSCALE_MAP[self.F_pipeline.F_qscale],
F_skip=BOOL_MAP[self.F_pipeline.F_skip],
F_occupancy=self.F_tile.F_occupancy,
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
@@ -623,7 +626,7 @@ class FmhaFwdKernel:
bias=self.F_pipeline.F_bias,
lse=self.F_pipeline.F_lse,
dropout=self.F_pipeline.F_dropout,
-squant=self.F_pipeline.F_squant,
+qscale=self.F_pipeline.F_qscale,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
@@ -695,7 +698,7 @@ class KernelComponentFactoryGfx9:
# TODO: how to design this more generic?
pipelines = []
if dtype in ["fp32"]:
squant = "f"
qscale = "no"
for logits, mask, bias, lse, dropout, skip in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
@@ -704,11 +707,11 @@ class KernelComponentFactoryGfx9:
["t", "f"],
["t", "f"],
):
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
elif dtype in ["fp16", "bf16"]:
squant = "f"
qscale = "no"
for logits, mask, bias, lse, dropout, skip in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
@@ -718,28 +721,31 @@ class KernelComponentFactoryGfx9:
["t", "f"],
):
if hdim == 256 and hdim_v == 256:
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
# the below two are used for hdim vectorized load
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
else:
if bias == "bias":
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
else:
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
if receipt == 1 and bias != "bias":
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip
elif dtype in ["fp8", "fp8bf16", "fp8fp32"]:
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip
elif dtype in ["fp8bf16", "fp8fp32"]:
# no need lse/dropout kernels
-for logits, squant, mask, bias in itertools.product(
-["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
+for logits, qscale, mask, bias in itertools.product(
+["f"],
+["no", "pertensor"],
+get_mask_map(mask_impl).keys(),
+["no"],
):
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip
elif dtype in ["fp8fp16", "bf8"]:
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
elif dtype in ["fp8", "fp8fp16", "bf8"]:
# TODO
None
else:
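Net effect of the fp8 branch rewrite: bias is pinned to "no" (per the "Remove bias test case for fp8" bullet), the old t/f squant axis becomes the two qscale kinds, and plain "fp8" moves to the unimplemented arm. Counting the enumerated variants standalone (the mask keys are stand-ins for get_mask_map(mask_impl)):

```python
import itertools

masks = ["no", "causal"]  # stand-ins; real keys come from get_mask_map(mask_impl)
combos = list(itertools.product(
    ["f"],                 # logits_soft_cap disabled for fp8
    ["no", "pertensor"],   # qscale kinds, replacing the old ["t", "f"] squant flag
    masks,
    ["no"],                # bias dropped from the fp8 matrix
))
# two pipelines are appended per combination:
print(len(combos) * 2)  # 1 * 2 * 2 * 1 * 2 = 8 for these stand-in masks
```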
@@ -756,7 +762,7 @@ class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9):
dtype, hdim, hdim_v, receipt, mask_impl
)
if dtype in ["fp16", "bf16"]:
squant = "f"
qscale = "no"
for logits, mask, bias, lse, dropout, skip in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
@@ -772,8 +778,8 @@ class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9):
and dropout == "f"
and skip == "f"
):
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t")) # fmt: skip
return pipelines
@@ -810,7 +816,7 @@ class KernelComponentFactoryGfx12:
def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
pipelines = []
if dtype in ["fp16", "bf16"]:
squant = "f"
qscale = "no"
for logits, mask, bias, lse, dropout, skip in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
@@ -819,15 +825,15 @@ class KernelComponentFactoryGfx12:
["t", "f"],
["t", "f"],
):
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip
elif dtype in ["fp8", "fp8bf16", "fp8fp32"]:
# no need lse/dropout kernels
-for logits, squant, mask, bias in itertools.product(
-["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
+for logits, qscale, mask, bias in itertools.product(
+["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"]
):
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
else:
assert False
return pipelines
@@ -932,7 +938,7 @@ def get_fwd_blobs(
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "alibi"]
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_qscale == "no"
cond &= pipeline.F_skip == "f"
if not cond:
continue
@@ -941,7 +947,7 @@ def get_fwd_blobs(
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "bias"]
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_qscale == "no"
cond &= mode == "batch"
cond &= pipeline.F_skip == "f"
cond &= pipeline.F_logits == "f"
@@ -953,7 +959,7 @@ def get_fwd_blobs(
cond &= mode == "batch"
cond &= pipeline.F_vlayout == "row"
if dtype == "fp8bf16":
-cond &= hdim == 128
+cond &= hdim == 128 or hdim == 256
if not cond:
continue
# Aiter(mha_varlen_fwd) integration
@@ -962,7 +968,7 @@ def get_fwd_blobs(
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
if dtype == "fp8bf16":
-cond &= hdim == 128
+cond &= hdim == 128 or hdim == 256
if not cond:
continue
# aiter::mha_fwd C++ api integration
@@ -970,13 +976,13 @@ def get_fwd_blobs(
cond = dtype in ["fp16", "bf16", "fp8bf16"]
cond &= pipeline.F_vlayout == "row"
if dtype == "fp8bf16":
-cond &= hdim == 128
+cond &= hdim == 128 or hdim == 256
if not cond:
continue
elif receipt == 888:
cond = dtype in ["fp8", "fp8bf16", "fp8fp32"]
cond = dtype in ["fp8bf16", "fp8fp32"]
cond &= pipeline.F_vlayout == "row"
-cond &= hdim == 128
+cond &= hdim == 128 or hdim == 256
if not cond:
continue
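All four receipt-gated integrations relax the same head-dimension gate; as a predicate (a trivial restatement of the hunks above):

```python
def fp8_hdim_allowed(hdim: int) -> bool:
    # before this commit only hdim == 128 passed the fp8bf16/fp8fp32 gates;
    # hdim == 256 is what the "Support hdim=256" bullet adds
    return hdim in (128, 256)
```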