From 27f7ab4f2c3074b18d074da0416a1d3a194eda4f Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Fri, 25 Apr 2025 15:04:41 +0000 Subject: [PATCH] Use compiler builtin directly in f_silu for float type --- .../18_hstu_attention/hstu_attention_fwd_pipeline.hpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index ea82f9c43e..59ad8e7dfe 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -247,7 +247,14 @@ struct HstuAttentionFwdPipelineQRKSVS const auto f_silu = [](CompDataType& x) { const auto neg_one = ck_tile::type_convert(-1.0f); - x = x * __builtin_amdgcn_rcpf(neg_one - exp(x)); + if constexpr(std::is_same_v) + { + x = x * __builtin_amdgcn_rcpf(neg_one - __expf(x)); + } + else + { + x = x / (neg_one - exp(x)); + } }; using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());