diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index c7e9f441b9..8b7020ebec 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -124,7 +124,7 @@ auto create_args(int argc, char* argv[]) "# of splits for key/value. 0 to determine actual number by heuristic") .insert( "rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all") - .insert("rotary_interleaved", "1", "weather to apply interleaving RoPE") + .insert("rotary_interleaved", "1", "whether to apply interleaved RoPE") .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel"); @@ -575,14 +575,18 @@ bool run(const ck_tile::ArgParser& arg_parser) { ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(q_host); ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(k_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(knew_host); ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(v_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(vnew_host); ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(bias_host); } else if(init_method == "ni") { ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(q_host); ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(k_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(knew_host); ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(v_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(vnew_host); ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(bias_host); } else if(init_method == "uf" || init_method == "1") @@ -598,14 +602,18 @@ bool run(const ck_tile::ArgParser& arg_parser) { ck_tile::FillNormalDistribution{0.f, 3.f, seed}(q_host); ck_tile::FillNormalDistribution{0.f, 3.f, seed}(k_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(knew_host); ck_tile::FillNormalDistribution{0.f, 3.f, seed}(v_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(vnew_host); ck_tile::FillNormalDistribution{0.f, 3.f, seed}(bias_host); } else if(init_method == "tf" || init_method == "2") { ck_tile::FillTrigValue{}(q_host); ck_tile::FillTrigValue{}(k_host); + ck_tile::FillTrigValue{}(knew_host); ck_tile::FillTrigValue{}(v_host); + ck_tile::FillTrigValue{}(vnew_host); ck_tile::FillTrigValue{}(bias_host); } else if(init_method == "ufq" || init_method == "uf:q" || @@ -613,7 +621,9 @@ bool run(const ck_tile::ArgParser& arg_parser) { ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(q_host); ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(k_host); + ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(knew_host); ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(v_host); + ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(vnew_host); // bias_fp8 = qscale_bias * bias_fp32 float qscale_bias = (dtype_max / range_q) * (dtype_max / range_k);