diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index ea82f9c43e..59ad8e7dfe 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -247,7 +247,14 @@ struct HstuAttentionFwdPipelineQRKSVS const auto f_silu = [](CompDataType& x) { const auto neg_one = ck_tile::type_convert(-1.0f); - x = x * __builtin_amdgcn_rcpf(neg_one - exp(x)); + if constexpr(std::is_same_v) + { + x = x * __builtin_amdgcn_rcpf(neg_one - __expf(x)); + } + else + { + x = x / (neg_one - exp(x)); + } }; using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());