diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp index 69e9530bc8..4c17a91688 100644 --- a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp +++ b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp @@ -199,8 +199,8 @@ auto get_elimit() template <> auto get_elimit() { - double rtol = 2e-3; - double atol = 2e-3; + double rtol = 1e-3; + double atol = 1e-3; return ck_tile::make_tuple(rtol, atol); } diff --git a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp index 2319a893e6..9eeaaa35be 100644 --- a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp +++ b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp @@ -102,7 +102,7 @@ struct reference_hstu_attention auto silu = [&](CompDataType x) { const auto one = ck_tile::type_convert(1.0f); - return x / (one + std::exp(-x)) / ck_tile::type_convert(max_seqlen); + return x / (one + std::exp(-x)); }; auto f = [&](auto i_batch, auto i_head) { @@ -199,6 +199,8 @@ struct reference_hstu_attention }; }; + dot_prod = dot_prod / ck_tile::type_convert(max_seqlen); + if constexpr(kIsJagged) o_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sq, i_head, k) = ck_tile::type_convert(dot_prod);