mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE] fix some rand number init (#1287)
* add random norm
* normalized default to 0/3
* change squant->auto
[ROCm/composable_kernel commit: fcba889ef4]
This commit is contained in:
@@ -44,9 +44,9 @@ args:
|
||||
-range_v per-tensor quantization range of v. used if squant=1. (default:16)
|
||||
-range_p per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1)
|
||||
-range_o per-tensor quantization range of o (p*v). used if squant=1. (default:16)
|
||||
-squant if using static quantization fusion or not. 0: original flow(not prefered) (default:0)
|
||||
1: apply scale_p and scale_o with respect to P and O. calculate scale_s, scale_p,
|
||||
scale_o according to range_q, range_k, range_v, range_p, range_o
|
||||
-squant if using static quantization fusion or not. auto: fp8 will default use squant, other will not (default:auto)
|
||||
0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to P and O.
|
||||
calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, range_p, range_o
|
||||
-iperm permute input (default:1)
|
||||
if true, will be b*h*s*d, else b*s*h*d
|
||||
-operm permute output (default:1)
|
||||
@@ -64,8 +64,11 @@ args:
|
||||
-vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r)
|
||||
-lse 0 not store lse, 1 store lse (default:0)
|
||||
-kname if set to 1 will print kernel name (default:0)
|
||||
-init init method. 0:random int, 1:random float, 2:trig float, 3:quantization (default:1)
|
||||
-init init method. ui, uniform random int, ni, normalized random int (default:uf)
|
||||
uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, quantization
|
||||
-seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939)
|
||||
-warmup number of iterations before benchmark the kernel (default:5)
|
||||
-repeat number of iterations to benchmark the kernel (default:20)
|
||||
```
|
||||
Example: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case.
|
||||
|
||||
|
||||
@@ -60,12 +60,14 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.")
|
||||
.insert("range_p", "1", "per-tensor quantization range of p [e^(s-m)]. used if squant=1.")
|
||||
.insert("range_o", "16", "per-tensor quantization range of o (p*v). used if squant=1.")
|
||||
.insert(
|
||||
"squant",
|
||||
"0",
|
||||
"if using static quantization fusion or not. 0: original flow(not prefered)\n"
|
||||
"1: apply scale_p and scale_o with respect to P and O. calculate scale_s, scale_p,\n"
|
||||
"scale_o according to range_q, range_k, range_v, range_p, range_o")
|
||||
.insert("squant",
|
||||
"auto",
|
||||
"if using static quantization fusion or not. auto: fp8 will default use squant, "
|
||||
"other will not\n"
|
||||
"0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to "
|
||||
"P and O.\n"
|
||||
"calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, "
|
||||
"range_p, range_o")
|
||||
.insert("iperm",
|
||||
"1",
|
||||
"permute input\n"
|
||||
@@ -92,8 +94,11 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)")
|
||||
.insert("lse", "0", "0 not store lse, 1 store lse")
|
||||
.insert("kname", "0", "if set to 1 will print kernel name")
|
||||
.insert(
|
||||
"init", "1", "init method. 0:random int, 1:random float, 2:trig float, 3:quantization")
|
||||
.insert("init",
|
||||
"uf",
|
||||
"init method. ui, uniform random int, ni, normalized random int\n"
|
||||
"uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, "
|
||||
"quantization")
|
||||
.insert("seed",
|
||||
"11939",
|
||||
"random seed used for initializing input tensors. 0 for "
|
||||
@@ -107,7 +112,7 @@ auto create_args(int argc, char* argv[])
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
auto get_elimit(int /*init_method*/)
|
||||
auto get_elimit(std::string /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-3;
|
||||
double atol = 1e-3;
|
||||
@@ -115,9 +120,15 @@ auto get_elimit(int /*init_method*/)
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>(int init_method)
|
||||
auto get_elimit<ck_tile::bf16_t>(std::string init_method)
|
||||
{
|
||||
if(init_method == 0)
|
||||
if(init_method == "ui" || init_method == "ni")
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
else if(init_method == "nf")
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
@@ -132,9 +143,9 @@ auto get_elimit<ck_tile::bf16_t>(int init_method)
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::fp8_t>(int init_method)
|
||||
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
|
||||
{
|
||||
if(init_method == 0)
|
||||
if(init_method == "ui" || init_method == "ni")
|
||||
{
|
||||
unsigned max_rounding_point_distance = 0;
|
||||
double atol = 2e-3;
|
||||
@@ -182,15 +193,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(scale_s == .0f)
|
||||
scale_s = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q)); // TODO: q ? v ?
|
||||
|
||||
bool squant = arg_parser.get_bool("squant");
|
||||
if constexpr(!std::is_same_v<DataType, ck_tile::fp8_t>)
|
||||
{
|
||||
if(squant)
|
||||
std::string squant_str = arg_parser.get_str("squant");
|
||||
bool squant = [&]() {
|
||||
if(squant_str == "auto")
|
||||
{
|
||||
std::cerr << "static quantization only support fp8 for now" << std::endl;
|
||||
return false;
|
||||
if(data_type == "fp8")
|
||||
return true;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
return atoi(squant_str.c_str()) != 0 ? true : false;
|
||||
}();
|
||||
|
||||
float range_q = arg_parser.get_float("range_q");
|
||||
float range_k = arg_parser.get_float("range_k");
|
||||
@@ -217,7 +231,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
bias_info bias = bias_info::decode(arg_parser.get_str("bias"));
|
||||
mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k);
|
||||
|
||||
int init_method = arg_parser.get_int("init");
|
||||
std::string init_method = arg_parser.get_str("init");
|
||||
std::optional<uint32_t> seed = arg_parser.get_uint32("seed");
|
||||
if(*seed == 0)
|
||||
{
|
||||
@@ -319,28 +333,43 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::HostTensor<ODataType> o_host(
|
||||
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
|
||||
|
||||
if(init_method == 0)
|
||||
if(init_method == "ui" || init_method == "0")
|
||||
{
|
||||
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-2.f, 2.f, seed}(bias_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
else if(init_method == "ni")
|
||||
{
|
||||
ck_tile::FillNormalDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(k_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(v_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
|
||||
}
|
||||
else if(init_method == "uf" || init_method == "1")
|
||||
{
|
||||
ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
|
||||
}
|
||||
else if(init_method == 2)
|
||||
else if(init_method == "nf")
|
||||
{
|
||||
ck_tile::FillNormalDistribution<QDataType>{0.f, 3.f, seed}(q_host);
|
||||
ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, seed}(k_host);
|
||||
ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, seed}(v_host);
|
||||
ck_tile::FillNormalDistribution<BiasDataType>{0.f, 3.f, seed}(bias_host);
|
||||
}
|
||||
else if(init_method == "tf" || init_method == "2")
|
||||
{
|
||||
ck_tile::FillTrigValue<QDataType>{}(q_host);
|
||||
ck_tile::FillTrigValue<KDataType>{}(k_host);
|
||||
ck_tile::FillTrigValue<VDataType>{}(v_host);
|
||||
ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
|
||||
}
|
||||
else if(init_method == 3) // suitable for fp8 quantization
|
||||
else if(init_method == "ufq" || init_method == "uf:q" ||
|
||||
init_method == "3") // suitable for fp8 quantization
|
||||
{
|
||||
ck_tile::FillUniformDistribution<QDataType>{-dtype_max, dtype_max, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(k_host);
|
||||
|
||||
Reference in New Issue
Block a user