From aec19176d4f0b6e20358afff7784f44964be9b02 Mon Sep 17 00:00:00 2001
From: Qianfeng Zhang
Date: Thu, 24 Apr 2025 05:47:24 +0000
Subject: [PATCH] Combine minus with scale_s

---
Reviewer notes (free-form text in this position is not part of the diff):

* Both hunks in hstu_attention_fwd_kernel.hpp pass -scale_s instead of
  scale_s, so code that receives this argument now sees the negated scale.
* In the bias hunk of hstu_attention_fwd_pipeline.hpp the bias term is
  subtracted instead of added. Together with the negated scale, x becomes
  -(x * scale + bias) relative to the value computed before this patch.
* f_silu is adjusted to match: with x = -s,
      x / (-1 - exp(x)) = -s / (-1 - exp(-s)) = s / (1 + exp(-s)),
  which is the expression the old body evaluated, now without the explicit
  negation of x inside exp().
* NOTE(review): f_silu now expects an already-negated input everywhere it is
  applied. Only the bias path is visible in this patch; confirm that every
  other place that scales by scale_s before f_silu is covered by the negated
  argument and needs no further sign change.
* NOTE(review): this copy of the patch seems to have lost text that was in
  angle brackets (no author e-mail; type_convert(...) has no template
  argument). Leading whitespace of the hunk lines below was reconstructed as
  well. Re-export the commit before applying - TODO confirm.

 .../ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp | 4 ++--
 .../18_hstu_attention/hstu_attention_fwd_pipeline.hpp       | 6 +++---
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp
index 8bd68d7fd3..92351ecc69 100644
--- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp
+++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp
@@ -202,7 +202,7 @@ struct HstuAttentionFwdKernel
                      hdim_qk,
                      hdim_v,
                      num_head,
-                     scale_s,
+                     -scale_s,
                      seq_stride_q,
                      seq_stride_k,
                      seq_stride_v,
@@ -348,7 +348,7 @@ struct HstuAttentionFwdKernel
                      hdim_qk,
                      hdim_v,
                      num_head,
-                     scale_s,
+                     -scale_s,
                      seq_stride_q,
                      seq_stride_k,
                      seq_stride_v,
diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp
index e68f3ad85e..ea8d60ef6f 100644
--- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp
+++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp
@@ -245,9 +245,9 @@ struct HstuAttentionFwdPipelineQRKSVS
 
         // reduction function for softmax
         const auto f_silu = [](CompDataType& x) {
-            const auto one = ck_tile::type_convert(1.0f);
+            const auto neg_one = ck_tile::type_convert(-1.0f);
 
-            return x = x / (one + exp(-x));
+            return x = x / (neg_one - exp(x));
         };
 
         using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
@@ -338,7 +338,7 @@ struct HstuAttentionFwdPipelineQRKSVS
 
                     tile_elementwise_inout(
                         [&scale_s, &bias_element_func](auto& x, const auto& y) {
-                            x = x * scale_s + type_convert(bias_element_func(y));
+                            x = x * scale_s - type_convert(bias_element_func(y));
                         },
                         sacc_tiles[i_k1],
                         bias_tile);