Support fp8 dynamic quantization for fmha (#3206)

* Support qscale for dynamic quant, remove static quant

* Support hdim=256

* Remove bias test case for fp8

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: asleepzzz <hanwen.chang@amd.com>
rocking
2025-11-24 16:28:25 +08:00
committed by GitHub
parent 096f0a3b23
commit 5948dbffe4
17 changed files with 369 additions and 280 deletions

@@ -17,12 +17,12 @@ The executables reside in `bin` subdirectory of the build directory.
This example provides recipes for `tile_example_fmha_fwd`, `tile_example_fmha_bwd`, `tile_example_fmha_fwd_v3`.
> [!NOTE]
> `cmake-ck-dev.sh` is a CMake wrapper.
>
> The first argument is the path to composable_kernel sources.
>
> The second argument is the gfx architectures string (e.g. "gfx950" or "gfx90a;gfx942").
>
> The remaining arguments are optional and are passed through to CMake.
> E.g. `-G Ninja` specifies ninja as the build system.
@@ -61,15 +61,8 @@ args:
     -d head dim for q, k (default:128)
     -d_v head dim for v, -1 means equal to d (default:-1)
     -scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0)
-    note when squant=1, this value will be modified by range_q/k
-    -range_q per-tensor quantization range of q. used if squant=1. (default:16)
-    -range_k per-tensor quantization range of k. used if squant=1. (default:16)
-    -range_v per-tensor quantization range of v. used if squant=1. (default:16)
-    -range_p per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1)
-    -range_o per-tensor quantization range of o (p*v). used if squant=1. (default:16)
-    -squant if using static quantization fusion or not. auto: fp8 will default use squant, other will not (default:auto)
-    0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to P and O.
-    calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, range_p, range_o
+    -qscale n or 0, no scaling (default:n)
+    1: per-tensor quantization.
     -iperm permute input (default:1)
     if true, will be b*h*s*d, else b*s*h*d
     -operm permute output (default:1)
@@ -104,7 +97,7 @@ args:
Comma-separated list of length 'b'. If empty, no override
```
Example 1: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` runs an fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, in fp16.
Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` runs an fmha case with
batch=1, nhead=8, sequence length=16384, hdim=64, drop_seed=10 (in GPU memory), drop_offset=1234 (in GPU memory), in fp16.
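The `-qscale` option listed in the arguments selects per-tensor quantization, with scales determined at run time instead of the fixed `range_*` values used by the removed static path. As a rough illustration of the idea (not the kernel's actual implementation), the NumPy sketch below derives one scale per tensor from its absolute maximum and folds the resulting descale factors into `scale_s`. The helper names, the fp8 maximum of 448 (the OCP e4m3 value; other fp8 variants differ), and the crude 4-significant-bit rounding are all assumptions made for this example.

```python
import numpy as np

# Assumption: OCP e4m3 maximum. Other fp8 variants use a different maximum.
FP8_E4M3_MAX = 448.0

def fake_fp8(x):
    """Crudely emulate fp8 rounding: keep 4 significant bits, ignore subnormals."""
    m, e = np.frexp(x)                      # x = m * 2**e with 0.5 <= |m| < 1
    return np.ldexp(np.round(m * 16.0) / 16.0, e)

def quantize_per_tensor(x):
    """Hypothetical helper: one scale for the whole tensor, computed from its data."""
    amax = float(np.max(np.abs(x)))
    descale = amax / FP8_E4M3_MAX if amax > 0 else 1.0
    xq = fake_fp8(np.clip(x / descale, -FP8_E4M3_MAX, FP8_E4M3_MAX))
    return xq, descale

def attention_scores(q, k, scale_s):
    """Compute S = scale_s * Q K^T from quantized inputs.

    The per-tensor descales are scalars, so they can be folded into scale_s
    and applied once to the product.
    """
    qq, descale_q = quantize_per_tensor(q)
    kq, descale_k = quantize_per_tensor(k)
    return (qq @ kq.T) * (scale_s * descale_q * descale_k)
```

Because each descale is a single scalar, the correction costs one multiply on the result and needs no per-element dequantization of Q or K.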
## Padding Examples