Fix block scale init value (#3666)

* Make the blockscale descale range adapt to the maximum value of each of the Q/K/V data types

* Apply code formatting
This commit is contained in:
ltqin
2026-01-29 04:37:15 +08:00
committed by GitHub
parent 42048bdb7d
commit 654bec3362

View File

@@ -750,9 +750,21 @@ fwd_result fmha_fwd_run(mode_enum mode,
}
else if(qscale.type == quant_scale_enum::blockscale)
{
ck_tile::FillUniformDistribution<float>{0.012f, 0.015f, next_seed()}(q_descale_host);
ck_tile::FillUniformDistribution<float>{0.012f, 0.015f, next_seed()}(k_descale_host);
ck_tile::FillUniformDistribution<float>{0.012f, 0.015f, next_seed()}(v_descale_host);
float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
float qkv_max = 3.f;
float max_descale_q = qkv_max / q_dtype_max;
float max_descale_k = qkv_max / k_dtype_max;
float max_descale_v = qkv_max / v_dtype_max;
ck_tile::FillUniformDistribution<float>{max_descale_q * 0.8f, max_descale_q, next_seed()}(
q_descale_host);
ck_tile::FillUniformDistribution<float>{max_descale_k * 0.8f, max_descale_k, next_seed()}(
k_descale_host);
ck_tile::FillUniformDistribution<float>{max_descale_v * 0.8f, max_descale_v, next_seed()}(
v_descale_host);
}
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine);