From 0d266bfd65eee8f7104cd7b3be8e40c392569ccf Mon Sep 17 00:00:00 2001 From: coderfeli Date: Tue, 25 Mar 2025 03:01:27 +0000 Subject: [PATCH] add silu --- include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp | 4 ++-- .../reference_tensor_operation/cpu/reference_moe_gemm.hpp | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 0bebfbb45d..d203fc40fa 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -1423,8 +1423,8 @@ struct GridwiseMoeGemm constexpr auto cidx = Number{}; auto gate = scale_a * scale_gate * c_thread_buf[cidx]; auto up = scale_a * scale_up * c_thread_buf_up[cidx]; - // gate = gate * math::rcp(1.0 + math::exp(-gate)); - c_thread_buf(cidx) = gate + up; + gate = gate * math::rcp(1.0 + math::exp(-gate)); + c_thread_buf(cidx) = gate * up; }); }); }); diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp index 14de5eda54..5825a6e1ae 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp @@ -144,9 +144,10 @@ struct ReferenceMoeGemm : public device::BaseOperator arg.c_element_op_(v_c, v_acc); arg.c_element_op_(v_c_up, v_acc_up); v_c = v_c * arg.b_scale_e_n_(e, n) * arg.a_scale_t_(t); + v_c = v_c * (1.0 / (1.0 + math::exp(-v_c))); v_c_up = v_c_up * arg.b_scale_e_n_(e, n + full_n) * arg.a_scale_t_(t); - // arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up * (1.0 / (1.0 + math::exp(-v_c_up))); - arg.c_t_k_n_(t, topk_id, n) = v_c + v_c_up; + arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up; + // arg.c_t_k_n_(t, topk_id, n) = v_c + v_c_up; } };