mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Fix block scale init value (#3666)
* Make blockscale descale range adaptive to data type max value * format
This commit is contained in:
@@ -750,9 +750,21 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
}
|
||||
else if(qscale.type == quant_scale_enum::blockscale)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<float>{0.012f, 0.015f, next_seed()}(q_descale_host);
|
||||
ck_tile::FillUniformDistribution<float>{0.012f, 0.015f, next_seed()}(k_descale_host);
|
||||
ck_tile::FillUniformDistribution<float>{0.012f, 0.015f, next_seed()}(v_descale_host);
|
||||
float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
|
||||
float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
|
||||
float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
|
||||
|
||||
float qkv_max = 3.f;
|
||||
float max_descale_q = qkv_max / q_dtype_max;
|
||||
float max_descale_k = qkv_max / k_dtype_max;
|
||||
float max_descale_v = qkv_max / v_dtype_max;
|
||||
|
||||
ck_tile::FillUniformDistribution<float>{max_descale_q * 0.8f, max_descale_q, next_seed()}(
|
||||
q_descale_host);
|
||||
ck_tile::FillUniformDistribution<float>{max_descale_k * 0.8f, max_descale_k, next_seed()}(
|
||||
k_descale_host);
|
||||
ck_tile::FillUniformDistribution<float>{max_descale_v * 0.8f, max_descale_v, next_seed()}(
|
||||
v_descale_host);
|
||||
}
|
||||
|
||||
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine);
|
||||
|
||||
Reference in New Issue
Block a user