From ee4e2167165e4050ca265ea0b11d15a54395e2e4 Mon Sep 17 00:00:00 2001 From: ltqin Date: Thu, 29 Jan 2026 04:37:15 +0800 Subject: [PATCH] Fix block scale init value (#3666) * Make blockscale descale range adaptive to data type max value * format [ROCm/composable_kernel commit: 654bec3362e825c27f0374e9e4f4e5b970e0f86f] --- example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index b6287245a0..1227724d40 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -750,9 +750,21 @@ fwd_result fmha_fwd_run(mode_enum mode, } else if(qscale.type == quant_scale_enum::blockscale) { - ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(q_descale_host); - ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(k_descale_host); - ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(v_descale_host); + float q_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float k_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float v_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + + float qkv_max = 3.f; + float max_descale_q = qkv_max / q_dtype_max; + float max_descale_k = qkv_max / k_dtype_max; + float max_descale_v = qkv_max / v_dtype_max; + + ck_tile::FillUniformDistribution{max_descale_q * 0.8f, max_descale_q, next_seed()}( + q_descale_host); + ck_tile::FillUniformDistribution{max_descale_k * 0.8f, max_descale_k, next_seed()}( + k_descale_host); + ck_tile::FillUniformDistribution{max_descale_v * 0.8f, max_descale_v, next_seed()}( + v_descale_host); } iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine);