[rocm-libraries] ROCm/rocm-libraries#4368 (commit 17f7dfc)

[CK_TILE][FMHA] Support microscaling (mxfp8 and mxfp4) on
 gfx950 (#4368)

## Motivation

Microscaling types (mxfp8 and mxfp4) for fwd qr pipeline

## Technical Details

The microscaling is used when quant scale mode is
`BlockAttentionQuantScaleEnum::MX` and `Q/K/P/VDataType` are
fp8/bf8/fp4.

Supported features:
* only "qr" pipeline is implemented
* hdim 128 and 256 (smaller hdim are not possible due to restrictions of
"qr" pipeline, but they can be computed using instances with padding)
 * both 32x32x64 and 16x16x128 scale MFMAs are supported
 * Q and K scales are applied in hdim, V scales - in seqlen dimension
 * column-major V only
 * batch and group mode
 * bias, Alibi (tested but no instances by default, just like fp8)
 * masking etc.

Aiter PR with new API args: https://github.com/ROCm/aiter/pull/2008

## Test Plan

```
ninja test_ck_tile_fmha_fwd_mxfp8 && bin/test_ck_tile_fmha_fwd_mxfp8
ninja test_ck_tile_fmha_fwd_mxfp4 && bin/test_ck_tile_fmha_fwd_mxfp4
```

## Test Result

The tests must pass.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Anton Gorenko
2026-03-11 10:00:52 +00:00
committed by assistant-librarian[bot]
parent c85c272c39
commit 2312eef6c3
29 changed files with 2167 additions and 356 deletions

View File

@@ -48,8 +48,12 @@ auto create_args(int argc, char* argv[])
.insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)")
.insert("qscale",
"n",
"n or 0, no scale\n"
"pt or 1, per-tensor scale\n")
"quant scale:\n"
" n or 0, no scale\n"
" pt or 1, per-tensor scale\n"
" bs or 2, block scale\n"
" kvbs or 3, Q per-tensor, K/V per-page block scale\n"
" mx or 4, microscaling (exclusively for data types like mxfp8 and mxfp4)")
.insert("logits_soft_cap", "0", "attention logits soft capping value.")
.insert("iperm",
"1",
@@ -61,7 +65,7 @@ auto create_args(int argc, char* argv[])
"n or 0, no bias\n"
"e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n"
"a(libi) or 2, alibi with 1*h. a:1, b*h")
.insert("prec", "fp16", "data type. fp32/fp16/bf16/fp8/bf8")
.insert("prec", "fp16", "data type: fp32/fp16/bf16/fp8/fp8bf16/fp8fp32/mxfp8/mxfp4")
.insert("mask",
"0",
"0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
@@ -231,6 +235,10 @@ int main(int argc, char* argv[])
{
return run<FmhaFwdBf16>(arg_parser) == fwd_result::success ? 0 : -2;
}
else if(data_type == "fp8")
{
return run<FmhaFwdFp8>(arg_parser) == fwd_result::success ? 0 : -2;
}
else if(data_type == "fp8bf16")
{
return run<FmhaFwdFp8Bf16>(arg_parser) == fwd_result::success ? 0 : -2;
@@ -239,6 +247,14 @@ int main(int argc, char* argv[])
{
return run<FmhaFwdFp8Fp32>(arg_parser) == fwd_result::success ? 0 : -2;
}
else if(data_type == "mxfp8")
{
return run<FmhaFwdMxFp8>(arg_parser) == fwd_result::success ? 0 : -2;
}
else if(data_type == "mxfp4")
{
return run<FmhaFwdMxFp4>(arg_parser) == fwd_result::success ? 0 : -2;
}
std::cerr << "Unsupported precision: " << data_type << std::endl;
return -1;
}