diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index dab1620fb7..fc3a2e8bd0 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -450,6 +450,9 @@ struct HstuAttentionFwdPipelineQRKSVS tile_elementwise_inout(f_silu, pcomp_tile); + tile_elementwise_inout([&](auto& x) { x = x * type_convert(scale_p); }, + pcomp_tile); + seqlen_k_curr += kN0; if constexpr(kHasDropout) @@ -507,9 +510,6 @@ struct HstuAttentionFwdPipelineQRKSVS }; } while(seqlen_k_curr < seqlen_k_end); - tile_elementwise_inout([&](auto& x) { x = x * type_convert(scale_p); }, - o_acc); - o_acc = tile_elementwise_in(o_acc_element_func, o_acc); return o_acc; diff --git a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp index 08f561620a..4b35ef3884 100644 --- a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp +++ b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp @@ -191,7 +191,7 @@ struct reference_hstu_attention // SiLu element-wise for(CompDataType& elem : locals) - elem = silu(elem); + elem = silu(elem) * ck_tile::type_convert(scale_p); // second Gemm for(int k = 0; k < hdim_v; k++) @@ -221,8 +221,6 @@ struct reference_hstu_attention }; }; - dot_prod = dot_prod * ck_tile::type_convert(scale_p); - if constexpr(kIsJagged) o_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sq, i_head, k) = ck_tile::type_convert(dot_prod);