diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index b6287245a0..1227724d40 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -750,9 +750,21 @@ fwd_result fmha_fwd_run(mode_enum mode, } else if(qscale.type == quant_scale_enum::blockscale) { - ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(q_descale_host); - ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(k_descale_host); - ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(v_descale_host); + float q_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float k_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float v_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + + float qkv_max = 3.f; + float max_descale_q = qkv_max / q_dtype_max; + float max_descale_k = qkv_max / k_dtype_max; + float max_descale_v = qkv_max / v_dtype_max; + + ck_tile::FillUniformDistribution{max_descale_q * 0.8f, max_descale_q, next_seed()}( + q_descale_host); + ck_tile::FillUniformDistribution{max_descale_k * 0.8f, max_descale_k, next_seed()}( + k_descale_host); + ck_tile::FillUniformDistribution{max_descale_v * 0.8f, max_descale_v, next_seed()}( + v_descale_host); } iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine);