From 9ca5aca74d2c51295f25f9d591845143835f319a Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sat, 11 May 2024 00:03:39 +0800 Subject: [PATCH] [CK_TILE] fix some rand number init (#1287) * add random norm * normalized default to 0/3 * change squant->auto [ROCm/composable_kernel commit: fcba889ef461bb334e8f74ea465713f5b7611855] --- example/ck_tile/01_fmha/README.md | 11 ++-- example/ck_tile/01_fmha/fmha_fwd.cpp | 87 ++++++++++++++++++---------- 2 files changed, 65 insertions(+), 33 deletions(-) diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index fd5690a795..a3248e2a5e 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -44,9 +44,9 @@ args: -range_v per-tensor quantization range of v. used if squant=1. (default:16) -range_p per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1) -range_o per-tensor quantization range of o (p*v). used if squant=1. (default:16) - -squant if using static quantization fusion or not. 0: original flow(not prefered) (default:0) - 1: apply scale_p and scale_o with respect to P and O. calculate scale_s, scale_p, - scale_o according to range_q, range_k, range_v, range_p, range_o + -squant if using static quantization fusion or not. auto: fp8 will default use squant, other will not (default:auto) + 0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to P and O. + calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, range_p, range_o -iperm permute input (default:1) if true, will be b*h*s*d, else b*s*h*d -operm permute output (default:1) @@ -64,8 +64,11 @@ args: -vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r) -lse 0 not store lse, 1 store lse (default:0) -kname if set to 1 will print kernel name (default:0) - -init init method. 0:random int, 1:random float, 2:trig float, 3:quantization (default:1) + -init init method. ui, uniform random int, ni, normalized random int (default:uf) + uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, quantization -seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939) + -warmup number of iterations before benchmark the kernel (default:5) + -repeat number of iterations to benchmark the kernel (default:20) ``` Example: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case. diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 686633bb2d..74cb3657e6 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -60,12 +60,14 @@ auto create_args(int argc, char* argv[]) .insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.") .insert("range_p", "1", "per-tensor quantization range of p [e^(s-m)]. used if squant=1.") .insert("range_o", "16", "per-tensor quantization range of o (p*v). used if squant=1.") - .insert( - "squant", - "0", - "if using static quantization fusion or not. 0: original flow(not prefered)\n" - "1: apply scale_p and scale_o with respect to P and O. calculate scale_s, scale_p,\n" - "scale_o according to range_q, range_k, range_v, range_p, range_o") + .insert("squant", + "auto", + "if using static quantization fusion or not. auto: fp8 will default use squant, " + "other will not\n" + "0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to " + "P and O.\n" + "calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, " + "range_p, range_o") .insert("iperm", "1", "permute input\n" @@ -92,8 +94,11 @@ auto create_args(int argc, char* argv[]) .insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)") .insert("lse", "0", "0 not store lse, 1 store lse") .insert("kname", "0", "if set to 1 will print kernel name") - .insert( - "init", "1", "init method. 0:random int, 1:random float, 2:trig float, 3:quantization") + .insert("init", + "uf", + "init method. ui, uniform random int, ni, normalized random int\n" + "uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, " + "quantization") .insert("seed", "11939", "random seed used for initializing input tensors. 0 for " @@ -107,7 +112,7 @@ auto create_args(int argc, char* argv[]) // different threshold for different dtype template -auto get_elimit(int /*init_method*/) +auto get_elimit(std::string /*init_method*/) { double rtol = 1e-3; double atol = 1e-3; @@ -115,9 +120,15 @@ auto get_elimit(int /*init_method*/) } template <> -auto get_elimit(int init_method) +auto get_elimit(std::string init_method) { - if(init_method == 0) + if(init_method == "ui" || init_method == "ni") + { + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); + } + else if(init_method == "nf") { double rtol = 1e-2; double atol = 1e-2; @@ -132,9 +143,9 @@ auto get_elimit(int init_method) } template <> -auto get_elimit(int init_method) +auto get_elimit(std::string init_method) { - if(init_method == 0) + if(init_method == "ui" || init_method == "ni") { unsigned max_rounding_point_distance = 0; double atol = 2e-3; @@ -182,15 +193,18 @@ bool run(const ck_tile::ArgParser& arg_parser) if(scale_s == .0f) scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); // TODO: q ? v ? - bool squant = arg_parser.get_bool("squant"); - if constexpr(!std::is_same_v) - { - if(squant) + std::string squant_str = arg_parser.get_str("squant"); + bool squant = [&]() { + if(squant_str == "auto") { - std::cerr << "static quantization only support fp8 for now" << std::endl; - return false; + if(data_type == "fp8") + return true; + else + return false; } - } + else + return atoi(squant_str.c_str()) != 0 ? true : false; + }(); float range_q = arg_parser.get_float("range_q"); float range_k = arg_parser.get_float("range_k"); @@ -217,7 +231,7 @@ bool run(const ck_tile::ArgParser& arg_parser) bias_info bias = bias_info::decode(arg_parser.get_str("bias")); mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k); - int init_method = arg_parser.get_int("init"); + std::string init_method = arg_parser.get_str("init"); std::optional seed = arg_parser.get_uint32("seed"); if(*seed == 0) { @@ -319,28 +333,43 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor o_host( get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); - if(init_method == 0) + if(init_method == "ui" || init_method == "0") { - ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(q_host); - ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(k_host); - ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(v_host); - ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(bias_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(q_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(k_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(v_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(bias_host); } - else if(init_method == 1) + else if(init_method == "ni") + { + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(q_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(k_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(v_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(bias_host); + } + else if(init_method == "uf" || init_method == "1") { ck_tile::FillUniformDistribution{0.f, 1.f, seed}(q_host); ck_tile::FillUniformDistribution{0.f, 1.f, seed}(k_host); ck_tile::FillUniformDistribution{0.f, 1.f, seed}(v_host); ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host); } - else if(init_method == 2) + else if(init_method == "nf") + { + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(q_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(k_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(v_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(bias_host); + } + else if(init_method == "tf" || init_method == "2") { ck_tile::FillTrigValue{}(q_host); ck_tile::FillTrigValue{}(k_host); ck_tile::FillTrigValue{}(v_host); ck_tile::FillTrigValue{}(bias_host); } - else if(init_method == 3) // suitable for fp8 quantization + else if(init_method == "ufq" || init_method == "uf:q" || + init_method == "3") // suitable for fp8 quantization { ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(q_host); ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(k_host);