From 374e0626e66caa3a519fb6b8f11e4051ffa2fdab Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Tue, 6 May 2025 08:24:03 +0000 Subject: [PATCH] Remove use of cast_tile_pk_fp16_fp32 for better accuracy in fp16 hstu attention --- .../18_hstu_attention/example_hstu_attention.cpp | 4 ++-- .../18_hstu_attention/hstu_attention_fwd_pipeline.hpp | 10 ++-------- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp index a1cdbd6663..bafc573a7b 100644 --- a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp +++ b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp @@ -162,8 +162,8 @@ static void show_hstu_attention_fwd_param(std::ostream& os, HstuAttentionFwdPara template auto get_elimit() { - double rtol = 2e-2; - double atol = 2e-2; + double rtol = 1e-2; + double atol = 1e-2; return ck_tile::make_tuple(rtol, atol); } diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index cca5f2dc48..283c287341 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -446,14 +446,8 @@ struct HstuAttentionFwdPipelineQRKSVS randval_lds_ptr, seqlen_k_curr, pcomp_tiles[i_k1], null_randval_window); } - auto p = [&]() { - if constexpr(std::is_same_v) - return impl::cast_tile_pk_fp16_fp32( - tile_elementwise_in(p_compute_element_func, pcomp_tiles[i_k1])); - else - return cast_tile( - tile_elementwise_in(p_compute_element_func, pcomp_tiles[i_k1])); - }(); + auto p = cast_tile( + tile_elementwise_in(p_compute_element_func, pcomp_tiles[i_k1])); block_sync_lds();