mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#4368 (commit 17f7dfc)
[CK_TILE][FMHA] Support microscaling (mxfp8 and mxfp4) on gfx950 (#4368)

## Motivation

Add microscaling types (mxfp8 and mxfp4) to the fwd "qr" pipeline.

## Technical Details

Microscaling is used when the quant scale mode is `BlockAttentionQuantScaleEnum::MX` and `Q/K/P/VDataType` are fp8/bf8/fp4.

Supported features:

* only the "qr" pipeline is implemented
* hdim 128 and 256 (smaller hdims are not possible due to restrictions of the "qr" pipeline, but they can be computed using instances with padding)
* both 32x32x64 and 16x16x128 scale MFMAs are supported
* Q and K scales are applied in the hdim dimension, V scales in the seqlen dimension
* column-major V only
* batch and group mode
* bias, Alibi (tested but no instances by default, just like fp8)
* masking etc.

Aiter PR with new API args: https://github.com/ROCm/aiter/pull/2008

## Test Plan

```
ninja test_ck_tile_fmha_fwd_mxfp8 && bin/test_ck_tile_fmha_fwd_mxfp8
ninja test_ck_tile_fmha_fwd_mxfp4 && bin/test_ck_tile_fmha_fwd_mxfp4
```

## Test Result

The tests must pass.

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
c85c272c39
commit
2312eef6c3
@@ -9,6 +9,8 @@ FWD_DTYPE_MAP = {
|
||||
"fp8fp16": "FmhaFwdFp8Fp16",
|
||||
"fp8bf16": "FmhaFwdFp8Bf16",
|
||||
"fp8fp32": "FmhaFwdFp8Fp32",
|
||||
"mxfp8": "FmhaFwdMxFp8",
|
||||
"mxfp4": "FmhaFwdMxFp4",
|
||||
}
|
||||
|
||||
BWD_DTYPE_MAP = {"fp32": "FmhaBwdFp32", "fp16": "FmhaBwdFp16", "bf16": "FmhaBwdBf16"}
|
||||
@@ -79,6 +81,7 @@ QSCALE_MAP = {
|
||||
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
|
||||
"blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
|
||||
"kv_blockscale": "ck_tile::BlockAttentionQuantScaleEnum::KV_BLOCKSCALE",
|
||||
"mx": "ck_tile::BlockAttentionQuantScaleEnum::MX",
|
||||
}
|
||||
|
||||
QSCALE_CHECK_MAP = {
|
||||
@@ -86,6 +89,7 @@ QSCALE_CHECK_MAP = {
|
||||
"pertensor": "quant_scale_enum::pertensor",
|
||||
"blockscale": "quant_scale_enum::blockscale",
|
||||
"kv_blockscale": "quant_scale_enum::kv_blockscale",
|
||||
"mx": "quant_scale_enum::mx",
|
||||
}
|
||||
|
||||
BIAS_MAP = {
|
||||
|
||||
@@ -38,6 +38,8 @@ DTYPE_BITS = {
|
||||
"fp8bf16": 8,
|
||||
"fp8fp32": 8,
|
||||
"bf8": 8,
|
||||
"mxfp8": 8,
|
||||
"mxfp4": 4,
|
||||
}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {
|
||||
@@ -836,7 +838,8 @@ class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory):
|
||||
def check_hdim_tile(
|
||||
problem_ctx: ProblemContext, kernel_ctx: KernelContext
|
||||
) -> bool:
|
||||
if problem_ctx.dtype != "fp32":
|
||||
# FIX: too confusing that it has to know about mx types
|
||||
if problem_ctx.dtype not in ("fp32", "mxfp8", "mxfp4"):
|
||||
# TODO: update if >=gfx11 archs get qr_async and qr_async_trload support
|
||||
if kernel_ctx.pipeline.tag in cls._AVAILABLE_PIPELINES and (
|
||||
(
|
||||
@@ -966,8 +969,6 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
|
||||
return {
|
||||
(128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
} # fmt: skip
|
||||
else:
|
||||
raise ValueError(f"unsupported dtype={dtype}")
|
||||
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
@@ -1035,9 +1036,6 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
elif dtype in ["fp8", "fp8fp16", "bf8"]:
|
||||
# TODO
|
||||
pass
|
||||
return pipelines
|
||||
|
||||
|
||||
@@ -1046,6 +1044,17 @@ class KernelComponentFactoryGfx950(
|
||||
):
|
||||
arch = ArchTrait("gfx950")
|
||||
|
||||
_DT_MXFP8 = ("mxfp8",)
|
||||
_DT_MXFP4 = ("mxfp4",)
|
||||
|
||||
@classmethod
|
||||
def supported_dtypes(cls) -> Tuple[str]:
|
||||
return (
|
||||
KernelComponentFactoryGfx9.supported_dtypes()
|
||||
+ cls._DT_MXFP8
|
||||
+ cls._DT_MXFP4
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
|
||||
result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype)
|
||||
@@ -1054,6 +1063,18 @@ class KernelComponentFactoryGfx950(
|
||||
if (128, 128) in result.keys():
|
||||
result[(128, 128)].append(
|
||||
FmhaFwdTileSize(256, 32, 128, 128, 32, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1)) # fmt: skip
|
||||
elif dtype in cls._DT_MXFP8:
|
||||
return {
|
||||
# bm0, bn0, bk0, bn1, bk1,
|
||||
(128, 128) : [FmhaFwdTileSize(128, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 32, 32, 64, 32, 32, 64, -1)],
|
||||
(256, 256) : [FmhaFwdTileSize(128, 128, 128, 256, 128, 256, 4, 1, 1, 4, 1, 1, 16, 16, 128, 16, 16, 128, -1)],
|
||||
} # fmt: skip
|
||||
elif dtype in cls._DT_MXFP4:
|
||||
return {
|
||||
# bm0, bn0, bk0, bn1, bk1,
|
||||
(128, 128) : [FmhaFwdTileSize(128, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 32, 32, 64, 32, 32, 64, -1)],
|
||||
(256, 256) : [FmhaFwdTileSize(128, 128, 128, 256, 128, 256, 4, 1, 1, 4, 1, 1, 16, 16, 128, 16, 16, 128, -1)],
|
||||
} # fmt: skip
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
@@ -1091,6 +1112,19 @@ class KernelComponentFactoryGfx950(
|
||||
pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f",
|
||||
F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip
|
||||
|
||||
elif dtype in cls._DT_MXFP8 or dtype in cls._DT_MXFP4:
|
||||
# no need dropout kernels
|
||||
lse = "t"
|
||||
dropout = "f"
|
||||
for logits, qscale, mask, bias, sink in itertools.product(
|
||||
["f"],
|
||||
["mx"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["no"],
|
||||
["f", "t"],
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr", "col", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "col", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
return pipelines
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user