Support fp8 dynamic quantization for fmha (#3206)

* Support qscale for dynamic quant, remove static quant

* Support hdim=256

* Remove bias test case for fp8

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: asleepzzz <hanwen.chang@amd.com>
This commit is contained in:
rocking
2025-11-24 16:28:25 +08:00
committed by GitHub
parent 096f0a3b23
commit 5948dbffe4
17 changed files with 369 additions and 280 deletions

View File

@@ -45,18 +45,12 @@ auto create_args(int argc, char* argv[])
"must be greater than or equal to s_k")
.insert("d", "128", "head dim for q, k")
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
.insert("scale_s",
"0",
"scale factor of S. 0 means equal to 1/sqrt(hdim).\n"
"note when squant=1, this value will be modified")
.insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)")
.insert("qscale",
"n",
"n or 0, no scale\n"
"pt or 1, per-tensor scale\n")
.insert("logits_soft_cap", "0", "attention logits soft capping value.")
.insert("squant",
"auto",
"if using static quantization fusion or not. auto: fp8 will default use squant, "
"other will not\n"
"0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to "
"P and O.\n"
"calculate scale_s, scale_p, scale_o auto")
.insert("iperm",
"1",
"permute input\n"
@@ -87,7 +81,8 @@ auto create_args(int argc, char* argv[])
"uf",
"init method:\n ui or 0 - uniform random int\n ni - normalized random int"
"\n uf or 1 - uniform random float\n nf - normalized random float"
"\n tf or 2 - trig float\n")
"\n tf or 2 - trig float"
"\n tf or 3 - uniform random float, min max is the max of the type\n")
.insert("seed",
"11939",
"random seed used for initializing input tensors. 0 for "
@@ -152,6 +147,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size");
bool use_cache_batch_idx = arg_parser.get_bool("cache_batch_idx");
std::string bias_str = arg_parser.get_str("bias");
std::string qscale_str = arg_parser.get_str("qscale");
float p_drop = arg_parser.get_float("p_drop");
uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
@@ -162,13 +158,6 @@ auto run(const ck_tile::ArgParser& arg_parser)
std::string init_method = arg_parser.get_str("init");
uint32_t seed = arg_parser.get_uint32("seed");
bool squant = [&]() {
if(arg_parser.get_str("squant") == "auto")
return std::is_same_v<DataTypeConfig, FmhaFwdFp8>;
else
return arg_parser.get_bool("squant");
}();
ck_tile::stream_config stream_config{nullptr,
true,
/* log_level = */ (arg_parser.get_bool("kname") ? 1 : 0),
@@ -208,7 +197,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
drop_offset,
drop_prefs,
mask_str,
squant,
qscale_str,
is_rotary_interleaved,
num_splits,
init_method,
@@ -239,10 +228,6 @@ int main(int argc, char* argv[])
{
return run<FmhaFwdBf16>(arg_parser) == fwd_result::success ? 0 : -2;
}
else if(data_type == "fp8")
{
return run<FmhaFwdFp8>(arg_parser) == fwd_result::success ? 0 : -2;
}
else if(data_type == "fp8bf16")
{
return run<FmhaFwdFp8Bf16>(arg_parser) == fwd_result::success ? 0 : -2;