From f053ae2b5b3eb499b0bfbc54cae8b3767c388cc2 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 24 Jul 2024 07:12:06 +0000 Subject: [PATCH] Add missing init code --- example/ck_tile/01_fmha/fmha_fwd.cpp | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index c7e9f441b9..8b7020ebec 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -124,7 +124,7 @@ auto create_args(int argc, char* argv[]) "# of splits for key/value. 0 to determine actual number by heuristic") .insert( "rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all") - .insert("rotary_interleaved", "1", "weather to apply interleaving RoPE") + .insert("rotary_interleaved", "1", "whether to apply interleaved RoPE") .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel"); @@ -575,14 +575,18 @@ bool run(const ck_tile::ArgParser& arg_parser) { ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(q_host); ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(k_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(knew_host); ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(v_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(vnew_host); ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(bias_host); } else if(init_method == "ni") { ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(q_host); ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(k_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(knew_host); ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(v_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(vnew_host); ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(bias_host); } else if(init_method == "uf" || init_method == "1") @@ -598,14 +602,18 @@ bool run(const ck_tile::ArgParser& arg_parser) { ck_tile::FillNormalDistribution{0.f, 3.f, seed}(q_host); ck_tile::FillNormalDistribution{0.f, 3.f, seed}(k_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(knew_host); ck_tile::FillNormalDistribution{0.f, 3.f, seed}(v_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(vnew_host); ck_tile::FillNormalDistribution{0.f, 3.f, seed}(bias_host); } else if(init_method == "tf" || init_method == "2") { ck_tile::FillTrigValue{}(q_host); ck_tile::FillTrigValue{}(k_host); + ck_tile::FillTrigValue{}(knew_host); ck_tile::FillTrigValue{}(v_host); + ck_tile::FillTrigValue{}(vnew_host); ck_tile::FillTrigValue{}(bias_host); } else if(init_method == "ufq" || init_method == "uf:q" || @@ -613,7 +621,9 @@ bool run(const ck_tile::ArgParser& arg_parser) { ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(q_host); ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(k_host); + ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(knew_host); ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(v_host); + ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(vnew_host); // bias_fp8 = qscale_bias * bias_fp32 float qscale_bias = (dtype_max / range_q) * (dtype_max / range_k);