mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 21:27:45 +00:00
add silu
This commit is contained in:
@@ -1423,8 +1423,8 @@ struct GridwiseMoeGemm
|
||||
constexpr auto cidx = Number<c_offset>{};
|
||||
auto gate = scale_a * scale_gate * c_thread_buf[cidx];
|
||||
auto up = scale_a * scale_up * c_thread_buf_up[cidx];
|
||||
// gate = gate * math::rcp(1.0 + math::exp(-gate));
|
||||
c_thread_buf(cidx) = gate + up;
|
||||
gate = gate * math::rcp(1.0 + math::exp(-gate));
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -144,9 +144,10 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
arg.c_element_op_(v_c, v_acc);
|
||||
arg.c_element_op_(v_c_up, v_acc_up);
|
||||
v_c = v_c * arg.b_scale_e_n_(e, n) * arg.a_scale_t_(t);
|
||||
v_c = v_c * (1.0 / (1.0 + math::exp(-v_c)));
|
||||
v_c_up = v_c_up * arg.b_scale_e_n_(e, n + full_n) * arg.a_scale_t_(t);
|
||||
// arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up * (1.0 / (1.0 + math::exp(-v_c_up)));
|
||||
arg.c_t_k_n_(t, topk_id, n) = v_c + v_c_up;
|
||||
arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up;
|
||||
// arg.c_t_k_n_(t, topk_id, n) = v_c + v_c_up;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user