mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Combine minus with scale_s (fold the negation used by SiLU's exp(-x) and by the bias add into the scale factor)
This commit is contained in:
@@ -202,7 +202,7 @@ struct HstuAttentionFwdKernel
|
||||
hdim_qk,
|
||||
hdim_v,
|
||||
num_head,
|
||||
scale_s,
|
||||
-scale_s,
|
||||
seq_stride_q,
|
||||
seq_stride_k,
|
||||
seq_stride_v,
|
||||
@@ -348,7 +348,7 @@ struct HstuAttentionFwdKernel
|
||||
hdim_qk,
|
||||
hdim_v,
|
||||
num_head,
|
||||
scale_s,
|
||||
-scale_s,
|
||||
seq_stride_q,
|
||||
seq_stride_k,
|
||||
seq_stride_v,
|
||||
|
||||
@@ -245,9 +245,9 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
// element-wise SiLU activation (not a softmax reduction): silu(z) = z / (1 + exp(-z))
|
||||
const auto f_silu = [](CompDataType& x) {
|
||||
const auto one = ck_tile::type_convert<CompDataType>(1.0f);
|
||||
const auto neg_one = ck_tile::type_convert<CompDataType>(-1.0f);
|
||||
|
||||
return x = x / (one + exp(-x));
|
||||
return x = x / (neg_one - exp(x));
|
||||
};
|
||||
|
||||
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
|
||||
@@ -338,7 +338,7 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
tile_elementwise_inout(
|
||||
[&scale_s, &bias_element_func](auto& x, const auto& y) {
|
||||
x = x * scale_s + type_convert<GemmAccDataType>(bias_element_func(y));
|
||||
x = x * scale_s - type_convert<GemmAccDataType>(bias_element_func(y));
|
||||
},
|
||||
sacc_tiles[i_k1],
|
||||
bias_tile);
|
||||
|
||||
Reference in New Issue
Block a user