From ab682f6b1c3f07704d555db5d132c283b7895fc0 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Wed, 8 Apr 2026 14:58:53 +0800 Subject: [PATCH] [CK_TILE] Refine FMHA Readme (#6003) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Updates the FMHA README to document fp8 precision support more accurately, replacing the outdated "experimental" section and incomplete CLI arg descriptions. ## Changes - **`-prec` arg**: expanded supported values from `fp16/bf16/fp8/bf8` → `fp32/fp16/bf16/fp8/fp8bf16/fp8fp32/mxfp8/mxfp4` - **`-qscale` arg**: replaced single-line `1: per-tensor quantization` with all four modes: `pt/1`, `bs/2`, `kvbs/3`, `mx/4` - **FP8 support section**: replaced "FP8 experimental support" paragraph with: - Supported targets: gfx942/gfx950 + ROCm 6.0+ - Table distinguishing `fp8` / `fp8bf16` / `fp8fp32` by Q/K/V input type and output type - Table for all `-qscale` modes with descriptions - Note that `-vlayout=r` (`seqlen*hdim` for V) is the only supported layout for fp8 types
Original prompt Please open a PR against base branch `develop` in repository `ROCm/rocm-libraries` applying the following documentation updates within the composable kernel path. ## Scope Update the file: - `projects/composablekernel/example/ck_tile/01_fmha/README.md` ## Changes to apply Apply the combined edits described in the diffs below (two consecutive patches). Ensure the final file content includes **both** sets of changes. ### Patch 1 - In the CLI args section: - Update `-qscale` description lines to include: - `pt or 1, per-tensor scale` - `bs or 2, block scale` - `kvbs or 3, Q per-tensor, K/V per-page block scale` - `mx or 4, microscaling (exclusively for mxfp8/mxfp4)` - Update `-prec` supported data types from `fp16/bf16/fp8/bf8` to `fp32/fp16/bf16/fp8/fp8bf16/fp8fp32/mxfp8/mxfp4`. - Replace the existing "FP8 experimental support" section with an "FP8 support" section stating: - FP8 FMHA kernels supported on gfx942/gfx950 with ROCm 6.0+ - Precision selectable via `-prec=fp8` (or `fp8bf16`, `fp8fp32`) for `tile_example_fmha_fwd` - Add a table describing `-qscale` modes: - `n` or `0`: No quantization scale (default) - `pt` or `1`: Per-tensor quantization scale - `bs` or `2`: Per-block quantization scale - `kvbs` or `3`: Q per-tensor + K/V per-page block scale - `mx` or `4`: Microscaling (MX format), exclusively for `mxfp8` and `mxfp4` - Add/keep note that currently only `-vlayout=r` (`seqlen*hdim` for V matrix) is supported for fp8 data types. ### Patch 2 Further refine the "FP8 support" paragraph to explain the difference between `fp8`, `fp8bf16`, and `fp8fp32` via a table: | `-prec` value | Q/K/V input type | Output type | Description | |---|---|---|---| | `fp8` | fp8 | fp8 | Fully fp8: both inputs and output are in fp8 | | `fp8bf16` | fp8 | bf16 | Mixed precision: fp8 inputs, bf16 output — useful when the consumer expects a wider-range output format | | `fp8fp32` | fp8 | fp32 | Mixed precision: fp8 inputs, fp32 output — highest-precision output, suitable for debugging or further fp32 processing | Keep the rest of the `-qscale` table and the `-vlayout=r` limitation note. ## Notes - PR title must be: `[CK_TILE] Add fp8 in FMHA readme` - Ensure markdown formatting renders correctly (tables, code formatting). - Only modify the file listed above. The following is the prior conversation context from the user's chat exploration (may be truncated): User: 能幫我上這個pr嗎 在composable kernel裡的路徑 diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index 0b526f4e9fc..1627435863b 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -62,14 +62,17 @@ args: -d_v head dim for v, -1 means equal to d (default:-1) -scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0) -qscale n or 0, no scaling (default:n) - 1: per-tensor quantization. + pt or 1, per-tensor scale + bs or 2, block scale + kvbs or 3, Q per-tensor, K/V per-page block scale + mx or 4, microscaling (exclusively for mxfp8/mxfp4) -iperm permute input (default:1) if true, will be b*h*s*d, else b*s*h*d -operm permute output (default:1) -bias n or 0, no bias (default:n) e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s a(libi) or 2, alibi with 1*h. a:1, b*h - -prec data type. fp16/bf16/fp8/bf8 (default:fp16) + -prec data type. fp32/fp16/bf16/fp8/fp8bf16/fp8fp32/mxfp8/mxfp4 (default:fp16) -mask 0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0) 't', top-left causal mask, 'b', bottom-r causal mask 't:l,r', top-left sliding window attn(swa) with FA style left right size @@ -161,7 +164,17 @@ We support sequence padding and variable-length processing in both batch and gro Both approaches optimize memory access patterns while supporting flexible sequence length requirements commonly found in transformer inference scenarios. -## FP8 experimental support -As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_fmha_fwd`, on a gfx942 machine and ROCm 6.0+. +## FP8 support +FP8 FMHA kernels are supported on gfx942/gfx950 machines with ROCm 6.0+. You can select fp8 precision by setting the arg `-prec=fp8` (or `fp8bf16`, `fp8fp32`) to the `tile_example_fmha_fwd`. -Currently we only support `-vlayout=r`( `seqlen*hdim` for V matrix) for fp8 and fp8bf16 now. Full feature support will come later. +The following quantization scale modes are available via `-qscale`: + +| `-qscale` value | Description | +|---|---| +| `n` or `0` | No quantization sca...
*This pull request was created from Copilot chat.* > --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: asleepzzz <4926646+asleepzzz@users.noreply.github.com> Co-authored-by: asleepzzz --- example/ck_tile/01_fmha/README.md | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index 0b526f4e9f..2aaaa45a9a 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -62,14 +62,17 @@ args: -d_v head dim for v, -1 means equal to d (default:-1) -scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0) -qscale n or 0, no scaling (default:n) - 1: per-tensor quantization. + pt or 1, per-tensor scale + bs or 2, block scale + kvbs or 3, Q per-tensor, K/V per-page block scale, only in batch_prefill + mx or 4, microscaling (exclusively for mxfp8/mxfp4) -iperm permute input (default:1) if true, will be b*h*s*d, else b*s*h*d -operm permute output (default:1) -bias n or 0, no bias (default:n) e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s a(libi) or 2, alibi with 1*h. a:1, b*h - -prec data type. fp16/bf16/fp8/bf8 (default:fp16) + -prec data type. fp32/fp16/bf16/fp8/fp8bf16/fp8fp32/mxfp8/mxfp4 (default:fp16) -mask 0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0) 't', top-left causal mask, 'b', bottom-r causal mask 't:l,r', top-left sliding window attn(swa) with FA style left right size @@ -161,7 +164,23 @@ We support sequence padding and variable-length processing in both batch and gro Both approaches optimize memory access patterns while supporting flexible sequence length requirements commonly found in transformer inference scenarios. -## FP8 experimental support -As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_fmha_fwd`, on a gfx942 machine and ROCm 6.0+. +## FP8 support +FP8 FMHA kernels are supported on gfx942/gfx950 machines with ROCm 6.0+. Three fp8-based precision modes are available via `-prec`: -Currently we only support `-vlayout=r`( `seqlen*hdim` for V matrix) for fp8 and fp8bf16 now. Full feature support will come later. +| `-prec` value | Q/K/V input type | Output type | Description | +|---|---|---|---| +| `fp8` | fp8 | fp8 | Fully fp8: both inputs and output are in fp8 | +| `fp8bf16` | fp8 | bf16 | Mixed precision: fp8 inputs, bf16 output — useful when the consumer expects a wider-range output format | +| `fp8fp32` | fp8 | fp32 | Mixed precision: fp8 inputs, fp32 output — highest-precision output, suitable for debugging or further fp32 processing | + +The following quantization scale modes are available via `-qscale`: + +| `-qscale` value | Description | +|---|---| +| `n` or `0` | No quantization scale (default) | +| `pt` or `1` | Per-tensor quantization scale — a single scale factor is applied to the entire tensor | +| `bs` or `2` | Per-block quantization scale — a scale factor is applied per block of elements | +| `kvbs` or `3` | Q per-tensor + K/V per-page block scale (batch_prefill only) | +| `mx` or `4` | Microscaling (MX format), exclusively for `mxfp8` and `mxfp4` data types | + +Currently only `-vlayout=r` (`seqlen*hdim` for V matrix) is supported for fp8 data types.