[rocm-libraries] ROCm/rocm-libraries#4368 (commit 17f7dfc)

[CK_TILE][FMHA] Support microscaling (mxfp8 and mxfp4) on
 gfx950 (#4368)

## Motivation

Microscaling types (mxfp8 and mxfp4) for fwd qr pipeline

## Technical Details

Microscaling is used when the quant scale mode is
`BlockAttentionQuantScaleEnum::MX` and `Q/K/P/VDataType` are
fp8/bf8/fp4.

Supported features:
* only "qr" pipeline is implemented
* hdim 128 and 256 (smaller hdims are not supported directly due to restrictions
of the "qr" pipeline, but they can be computed using instances with padding)
 * both 32x32x64 and 16x16x128 scale MFMAs are supported
 * Q and K scales are applied in hdim, V scales - in seqlen dimension
 * column-major V only
 * batch and group mode
 * bias, Alibi (tested but no instances by default, just like fp8)
 * masking etc.

Aiter PR with new API args: https://github.com/ROCm/aiter/pull/2008

## Test Plan

```
ninja test_ck_tile_fmha_fwd_mxfp8 && bin/test_ck_tile_fmha_fwd_mxfp8
ninja test_ck_tile_fmha_fwd_mxfp4 && bin/test_ck_tile_fmha_fwd_mxfp4
```

## Test Result

Both tests listed above pass.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Anton Gorenko
2026-03-11 10:00:52 +00:00
committed by assistant-librarian[bot]
parent c85c272c39
commit 2312eef6c3
29 changed files with 2167 additions and 356 deletions

View File

@@ -9,6 +9,8 @@ FWD_DTYPE_MAP = {
"fp8fp16": "FmhaFwdFp8Fp16",
"fp8bf16": "FmhaFwdFp8Bf16",
"fp8fp32": "FmhaFwdFp8Fp32",
"mxfp8": "FmhaFwdMxFp8",
"mxfp4": "FmhaFwdMxFp4",
}
BWD_DTYPE_MAP = {"fp32": "FmhaBwdFp32", "fp16": "FmhaBwdFp16", "bf16": "FmhaBwdBf16"}
@@ -79,6 +81,7 @@ QSCALE_MAP = {
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
"blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
"kv_blockscale": "ck_tile::BlockAttentionQuantScaleEnum::KV_BLOCKSCALE",
"mx": "ck_tile::BlockAttentionQuantScaleEnum::MX",
}
QSCALE_CHECK_MAP = {
@@ -86,6 +89,7 @@ QSCALE_CHECK_MAP = {
"pertensor": "quant_scale_enum::pertensor",
"blockscale": "quant_scale_enum::blockscale",
"kv_blockscale": "quant_scale_enum::kv_blockscale",
"mx": "quant_scale_enum::mx",
}
BIAS_MAP = {

View File

@@ -38,6 +38,8 @@ DTYPE_BITS = {
"fp8bf16": 8,
"fp8fp32": 8,
"bf8": 8,
"mxfp8": 8,
"mxfp4": 4,
}
K0_MAX_SUBMAX_MAP = {
@@ -836,7 +838,8 @@ class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory):
def check_hdim_tile(
problem_ctx: ProblemContext, kernel_ctx: KernelContext
) -> bool:
if problem_ctx.dtype != "fp32":
# FIX: too confusing that it has to know about mx types
if problem_ctx.dtype not in ("fp32", "mxfp8", "mxfp4"):
# TODO: update if >=gfx11 archs get qr_async and qr_async_trload support
if kernel_ctx.pipeline.tag in cls._AVAILABLE_PIPELINES and (
(
@@ -966,8 +969,6 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
return {
(128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
} # fmt: skip
else:
raise ValueError(f"unsupported dtype={dtype}")
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
@@ -1035,9 +1036,6 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
else:
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
elif dtype in ["fp8", "fp8fp16", "bf8"]:
# TODO
pass
return pipelines
@@ -1046,6 +1044,17 @@ class KernelComponentFactoryGfx950(
):
arch = ArchTrait("gfx950")
# Microscaling dtype tags added on this arch. Kept as one-element tuples so
# they can be concatenated onto the base factory's supported_dtypes() result
# and used with `in` membership tests.
_DT_MXFP8 = ("mxfp8",)
_DT_MXFP4 = ("mxfp4",)
@classmethod
def supported_dtypes(cls) -> Tuple[str, ...]:
    """Return the dtype tags this factory supports.

    Extends the base gfx9 dtype set with the microscaling types
    ("mxfp8", "mxfp4") that are only generated for this arch.
    """
    # NOTE(review): annotation widened to Tuple[str, ...] — Tuple[str]
    # would declare a fixed-length 1-tuple, but this returns a
    # variable-length tuple of tags.
    return (
        KernelComponentFactoryGfx9.supported_dtypes()
        + cls._DT_MXFP8
        + cls._DT_MXFP4
    )
@classmethod
def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype)
@@ -1054,6 +1063,18 @@ class KernelComponentFactoryGfx950(
if (128, 128) in result.keys():
result[(128, 128)].append(
FmhaFwdTileSize(256, 32, 128, 128, 32, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1)) # fmt: skip
elif dtype in cls._DT_MXFP8:
return {
# bm0, bn0, bk0, bn1, bk1,
(128, 128) : [FmhaFwdTileSize(128, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 32, 32, 64, 32, 32, 64, -1)],
(256, 256) : [FmhaFwdTileSize(128, 128, 128, 256, 128, 256, 4, 1, 1, 4, 1, 1, 16, 16, 128, 16, 16, 128, -1)],
} # fmt: skip
elif dtype in cls._DT_MXFP4:
return {
# bm0, bn0, bk0, bn1, bk1,
(128, 128) : [FmhaFwdTileSize(128, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 32, 32, 64, 32, 32, 64, -1)],
(256, 256) : [FmhaFwdTileSize(128, 128, 128, 256, 128, 256, 4, 1, 1, 4, 1, 1, 16, 16, 128, 16, 16, 128, -1)],
} # fmt: skip
return result
@classmethod
@@ -1091,6 +1112,19 @@ class KernelComponentFactoryGfx950(
pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f",
F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip
elif dtype in cls._DT_MXFP8 or dtype in cls._DT_MXFP4:
# no need dropout kernels
lse = "t"
dropout = "f"
for logits, qscale, mask, bias, sink in itertools.product(
["f"],
["mx"],
get_mask_map(mask_impl).keys(),
["no"],
["f", "t"],
):
pipelines.append(FmhaFwdPipeline("qr", "col", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, "f", "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "col", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, "f", "f", sink)) # fmt: skip
return pipelines