Use compiler builtin directly in f_silu for float type

This commit is contained in:
Qianfeng Zhang
2025-04-25 15:04:41 +00:00
parent 4ae9acd712
commit 27f7ab4f2c

View File

@@ -247,7 +247,14 @@ struct HstuAttentionFwdPipelineQRKSVS
// f_silu: updates x in place to x / (-1 - exp(x)).
// NOTE(review): this matches SiLU(s) = s / (1 + exp(-s)) only if the caller
// passes x = -s; confirm the sign convention at the call site.
const auto f_silu = [](CompDataType& x) {
    const auto neg_one = ck_tile::type_convert<CompDataType>(-1.0f);

    // The update is applied exactly once, in whichever branch matches the type.
    if constexpr(std::is_same_v<CompDataType, float>)
    {
        // float path: multiply by the AMDGCN reciprocal builtin instead of
        // dividing, and use the __expf intrinsic for the exponential.
        x = x * __builtin_amdgcn_rcpf(neg_one - __expf(x));
    }
    else
    {
        // Other compute types: the float-only builtins are not used here, so
        // fall back to an ordinary division with exp().
        x = x / (neg_one - exp(x));
    }
};
// Alias for the tile type returned by gemm_1.MakeCBlockTile().
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());