apply scale p

This commit is contained in:
rocking
2025-09-01 15:16:28 -05:00
parent b7dacea7c9
commit 2c297ba201

View File

@@ -1478,20 +1478,51 @@ struct FmhaFwdKernel
}
else
{
return FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
bias_dram_window,
randval_dram_window,
lse_dram_window,
mask,
position_encoding,
kargs.scale_s,
variant,
variant_params,
block_indices,
smem_ptr,
dropout);
// Compile-time dispatch on the Q element type: the fp8 case calls the pipeline
// with explicit per-stage element functions, every other type uses the shorter
// call without them.
if constexpr(std::is_same_v<QDataType, ck_tile::fp8_t>)
{
// invscale_p is numeric<QDataType>::max() converted to float; scale_p is its
// reciprocal, so the two factors cancel when both are applied.
// NOTE(review): the factor is derived from QDataType, not from the type P is
// stored in — presumably both are fp8_t on this path; confirm.
float invscale_p = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
float scale_p = 1.f / invscale_p;
// All element functions are identity except two: scales{invscale_p} is passed
// as p_compute_element_func and scales{scale_p} as o_acc_element_func.
// NOTE(review): presumably scales{x} multiplies by x, so P is scaled up before
// the second GEMM and the output accumulator is scaled back down — verify
// against the pipeline implementation.
return FmhaPipeline{}(q_dram_window,
identity{}, // q_element_func
k_dram_window,
identity{}, // k_element_func
v_dram_window,
identity{}, // v_element_func
bias_dram_window,
identity{}, // bias_element_func
randval_dram_window,
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales{invscale_p}, // p_compute_element_func
scales{scale_p}, // o_acc_element_func
mask,
position_encoding,
kargs.scale_s,
variant,
variant_params,
block_indices,
smem_ptr,
dropout);
}
else
{
// Non-fp8 Q types: call the pipeline without any element-function arguments.
return FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
bias_dram_window,
randval_dram_window,
lse_dram_window,
mask,
position_encoding,
kargs.scale_s,
variant,
variant_params,
block_indices,
smem_ptr,
dropout);
}
}
}();