mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
Add input fp8 and output bf16 attention (#2726)
* change host using fp16 to check * fp8 to fp8 compare * rewrite input parameters * add not squant * remove some output code * for scale = 1 * format * saturates only for fp8 * add fp8bf16 data type * add fp8bf16 data type * fix test fp8 code * add run_fp8bf16_tests * change fmha fwd example parameter(adding fp8bf16) * Support fp8bf16 for Aiter * Support aiter fp8bf16 in c++ * fix comment about fp8 in readme.md * add fp8fp32 * add fp8fp32 test * remove range_q etc. * format * fix test parameters about squant and fmha example input fp8bf16 fp8fp32 data type * add fp8bf16 to data_type function * change colmajor to rowmajor in test_ck_tile_fmha_fwd_fp8 * format * reset atol for fp8 * fix bug for atol --------- Co-authored-by: rocking <ChunYu.Lai@amd.com> Co-authored-by: asleepzzz <hanwen.chang@amd.com>
This commit is contained in:
@@ -44,21 +44,15 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("scale_s",
|
||||
"0",
|
||||
"scale factor of S. 0 means equal to 1/sqrt(hdim).\n"
|
||||
"note when squant=1, this value will be modified by range_q/k")
|
||||
"note when squant=1, this value will be modified")
|
||||
.insert("logits_soft_cap", "0", "attention logits soft capping value.")
|
||||
.insert("range_q", "16", "per-tensor quantization range of q. used if squant=1.")
|
||||
.insert("range_k", "16", "per-tensor quantization range of k. used if squant=1.")
|
||||
.insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.")
|
||||
.insert("range_p", "1", "per-tensor quantization range of p [e^(s-m)]. used if squant=1.")
|
||||
.insert("range_o", "16", "per-tensor quantization range of o (p*v). used if squant=1.")
|
||||
.insert("squant",
|
||||
"auto",
|
||||
"if using static quantization fusion or not. auto: fp8 will default use squant, "
|
||||
"other will not\n"
|
||||
"0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to "
|
||||
"P and O.\n"
|
||||
"calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, "
|
||||
"range_p, range_o")
|
||||
"calculate scale_s, scale_p, scale_o auto")
|
||||
.insert("iperm",
|
||||
"1",
|
||||
"permute input\n"
|
||||
@@ -89,7 +83,7 @@ auto create_args(int argc, char* argv[])
|
||||
"uf",
|
||||
"init method:\n ui or 0 - uniform random int\n ni - normalized random int"
|
||||
"\n uf or 1 - uniform random float\n nf - normalized random float"
|
||||
"\n tf or 2 - trig float\n uf:q or ufq or 3 - fp8 quantization")
|
||||
"\n tf or 2 - trig float\n")
|
||||
.insert("seed",
|
||||
"11939",
|
||||
"random seed used for initializing input tensors. 0 for "
|
||||
@@ -148,11 +142,6 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
|
||||
bool drop_prefs = arg_parser.get_bool("drop_prefs");
|
||||
std::string mask_str = arg_parser.get_str("mask");
|
||||
float range_q = arg_parser.get_float("range_q");
|
||||
float range_k = arg_parser.get_float("range_k");
|
||||
float range_v = arg_parser.get_float("range_v");
|
||||
float range_p = arg_parser.get_float("range_p");
|
||||
float range_o = arg_parser.get_float("range_o");
|
||||
bool is_rotary_interleaved = arg_parser.get_bool("rotary_interleaved");
|
||||
ck_tile::index_t num_splits = arg_parser.get_int("num_splits");
|
||||
std::string init_method = arg_parser.get_str("init");
|
||||
@@ -201,11 +190,6 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
drop_offset,
|
||||
drop_prefs,
|
||||
mask_str,
|
||||
range_q,
|
||||
range_k,
|
||||
range_v,
|
||||
range_p,
|
||||
range_o,
|
||||
squant,
|
||||
is_rotary_interleaved,
|
||||
num_splits,
|
||||
@@ -237,6 +221,14 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
return run<FmhaFwdFp8>(arg_parser) == fwd_result::success ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "fp8bf16")
|
||||
{
|
||||
return run<FmhaFwdFp8Bf16>(arg_parser) == fwd_result::success ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "fp8fp32")
|
||||
{
|
||||
return run<FmhaFwdFp8Fp32>(arg_parser) == fwd_result::success ? 0 : -2;
|
||||
}
|
||||
std::cerr << "Unsupported precision: " << data_type << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user