diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp index a1cdbd6663..bafc573a7b 100644 --- a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp +++ b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp @@ -162,8 +162,8 @@ static void show_hstu_attention_fwd_param(std::ostream& os, HstuAttentionFwdPara template auto get_elimit() { - double rtol = 2e-2; - double atol = 2e-2; + double rtol = 1e-2; + double atol = 1e-2; return ck_tile::make_tuple(rtol, atol); } diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index cca5f2dc48..283c287341 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -446,14 +446,8 @@ struct HstuAttentionFwdPipelineQRKSVS randval_lds_ptr, seqlen_k_curr, pcomp_tiles[i_k1], null_randval_window); } - auto p = [&]() { - if constexpr(std::is_same_v) - return impl::cast_tile_pk_fp16_fp32( - tile_elementwise_in(p_compute_element_func, pcomp_tiles[i_k1])); - else - return cast_tile( - tile_elementwise_in(p_compute_element_func, pcomp_tiles[i_k1])); - }(); + auto p = cast_tile( + tile_elementwise_in(p_compute_element_func, pcomp_tiles[i_k1])); block_sync_lds();