mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
Update README: FP8 is no longer experimental, document quantization parameters
Agent-Logs-Url: https://github.com/ROCm/composable_kernel/sessions/f07f07f0-9b53-4391-9807-a6a261768600 Co-authored-by: asleepzzz <4926646+asleepzzz@users.noreply.github.com>
This commit is contained in:
@@ -62,14 +62,17 @@ args:
|
||||
-d_v head dim for v, -1 means equal to d (default:-1)
|
||||
-scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0)
|
||||
-qscale n or 0, no scaling (default:n)
|
||||
1: per-tensor quantization.
|
||||
pt or 1, per-tensor scale
|
||||
bs or 2, block scale
|
||||
kvbs or 3, Q per-tensor, K/V per-page block scale
|
||||
mx or 4, microscaling (exclusively for mxfp8/mxfp4)
|
||||
-iperm permute input (default:1)
|
||||
if true, will be b*h*s*d, else b*s*h*d
|
||||
-operm permute output (default:1)
|
||||
-bias n or 0, no bias (default:n)
|
||||
e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s
|
||||
a(libi) or 2, alibi with 1*h. a:1, b*h
|
||||
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
|
||||
-prec data type. fp32/fp16/bf16/fp8/fp8bf16/fp8fp32/mxfp8/mxfp4 (default:fp16)
|
||||
-mask 0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0)
|
||||
't', top-left causal mask, 'b', bottom-r causal mask
|
||||
't:l,r', top-left sliding window attn(swa) with FA style left right size
|
||||
@@ -161,7 +164,17 @@ We support sequence padding and variable-length processing in both batch and gro
|
||||
|
||||
Both approaches optimize memory access patterns while supporting flexible sequence length requirements commonly found in transformer inference scenarios.
|
||||
|
||||
## FP8 experimental support
|
||||
As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_fmha_fwd`, on a gfx942 machine and ROCm 6.0+.
|
||||
## FP8 support
|
||||
FP8 FMHA kernels are supported on gfx942/gfx950 machines with ROCm 6.0+. You can select fp8 precision by setting the arg `-prec=fp8` (or `fp8bf16`, `fp8fp32`) to the `tile_example_fmha_fwd`.
|
||||
|
||||
Currently we only support `-vlayout=r`( `seqlen*hdim` for V matrix) for fp8 and fp8bf16 now. Full feature support will come later.
|
||||
The following quantization scale modes are available via `-qscale`:
|
||||
|
||||
| `-qscale` value | Description |
|
||||
|---|---|
|
||||
| `n` or `0` | No quantization scale (default) |
|
||||
| `pt` or `1` | Per-tensor quantization scale — a single scale factor is applied to the entire tensor |
|
||||
| `bs` or `2` | Per-block quantization scale — a scale factor is applied per block of elements |
|
||||
| `kvbs` or `3` | Q per-tensor + K/V per-page block scale — Q uses a single per-tensor scale, while K and V use per-page block scales |
|
||||
| `mx` or `4` | Microscaling (MX format), exclusively for `mxfp8` and `mxfp4` data types |
|
||||
|
||||
Currently only `-vlayout=r` (`seqlen*hdim` for V matrix) is supported for fp8 data types.
|
||||
|
||||
Reference in New Issue
Block a user