mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
Add missing init code
This commit is contained in:
@@ -124,7 +124,7 @@ auto create_args(int argc, char* argv[])
|
||||
"# of splits for key/value. 0 to determine actual number by heuristic")
|
||||
.insert(
|
||||
"rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all")
|
||||
.insert("rotary_interleaved", "1", "weather to apply interleaving RoPE")
|
||||
.insert("rotary_interleaved", "1", "whether to apply interleaved RoPE")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel");
|
||||
|
||||
@@ -575,14 +575,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(knew_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(vnew_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
|
||||
}
|
||||
else if(init_method == "ni")
|
||||
{
|
||||
ck_tile::FillNormalDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(k_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(knew_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(v_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(vnew_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
|
||||
}
|
||||
else if(init_method == "uf" || init_method == "1")
|
||||
@@ -598,14 +602,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::FillNormalDistribution<QDataType>{0.f, 3.f, seed}(q_host);
|
||||
ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, seed}(k_host);
|
||||
ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, seed}(knew_host);
|
||||
ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, seed}(v_host);
|
||||
ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, seed}(vnew_host);
|
||||
ck_tile::FillNormalDistribution<BiasDataType>{0.f, 3.f, seed}(bias_host);
|
||||
}
|
||||
else if(init_method == "tf" || init_method == "2")
|
||||
{
|
||||
ck_tile::FillTrigValue<QDataType>{}(q_host);
|
||||
ck_tile::FillTrigValue<KDataType>{}(k_host);
|
||||
ck_tile::FillTrigValue<KDataType>{}(knew_host);
|
||||
ck_tile::FillTrigValue<VDataType>{}(v_host);
|
||||
ck_tile::FillTrigValue<VDataType>{}(vnew_host);
|
||||
ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
|
||||
}
|
||||
else if(init_method == "ufq" || init_method == "uf:q" ||
|
||||
@@ -613,7 +621,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<QDataType>{-dtype_max, dtype_max, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(k_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(knew_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(v_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(vnew_host);
|
||||
|
||||
// bias_fp8 = qscale_bias * bias_fp32
|
||||
float qscale_bias = (dtype_max / range_q) * (dtype_max / range_k);
|
||||
|
||||
Reference in New Issue
Block a user