Use type_convert rather than static_cast in f_silu

This commit is contained in:
Qianfeng Zhang
2025-05-07 07:05:43 +00:00
parent 72d55d1b40
commit 079f7e3a03

View File

@@ -239,11 +239,11 @@ struct HstuAttentionFwdPipelineQRKSVS
if constexpr(std::is_same_v<CompDataType, float>)
{
x = x * __builtin_amdgcn_rcpf(neg_one - __expf(x)) *
__builtin_amdgcn_rcpf(static_cast<CompDataType>(max_seqlen));
__builtin_amdgcn_rcpf(type_convert<CompDataType>(max_seqlen));
}
else
{
x = x / (neg_one - exp(x)) / static_cast<CompDataType>(max_seqlen);
x = x / (neg_one - exp(x)) / type_convert<CompDataType>(max_seqlen);
}
};