Combine minus with scale_s

This commit is contained in:
Qianfeng Zhang
2025-04-24 05:47:24 +00:00
parent ce4665262b
commit aec19176d4
2 changed files with 5 additions and 5 deletions

View File

@@ -202,7 +202,7 @@ struct HstuAttentionFwdKernel
hdim_qk,
hdim_v,
num_head,
scale_s,
-scale_s,
seq_stride_q,
seq_stride_k,
seq_stride_v,
@@ -348,7 +348,7 @@ struct HstuAttentionFwdKernel
hdim_qk,
hdim_v,
num_head,
scale_s,
-scale_s,
seq_stride_q,
seq_stride_k,
seq_stride_v,

View File

@@ -245,9 +245,9 @@ struct HstuAttentionFwdPipelineQRKSVS
// reduction function for softmax
const auto f_silu = [](CompDataType& x) {
const auto one = ck_tile::type_convert<CompDataType>(1.0f);
const auto neg_one = ck_tile::type_convert<CompDataType>(-1.0f);
return x = x / (one + exp(-x));
return x = x / (neg_one - exp(x));
};
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
@@ -338,7 +338,7 @@ struct HstuAttentionFwdPipelineQRKSVS
tile_elementwise_inout(
[&scale_s, &bias_element_func](auto& x, const auto& y) {
x = x * scale_s + type_convert<GemmAccDataType>(bias_element_func(y));
x = x * scale_s - type_convert<GemmAccDataType>(bias_element_func(y));
},
sacc_tiles[i_k1],
bias_tile);