This commit is contained in:
coderfeli
2025-03-25 03:01:27 +00:00
parent 2b15b67b3f
commit 0d266bfd65
2 changed files with 5 additions and 4 deletions

View File

@@ -1423,8 +1423,8 @@ struct GridwiseMoeGemm
constexpr auto cidx = Number<c_offset>{};
auto gate = scale_a * scale_gate * c_thread_buf[cidx];
auto up = scale_a * scale_up * c_thread_buf_up[cidx];
// gate = gate * math::rcp(1.0 + math::exp(-gate));
c_thread_buf(cidx) = gate + up;
gate = gate * math::rcp(1.0 + math::exp(-gate));
c_thread_buf(cidx) = gate * up;
});
});
});

View File

@@ -144,9 +144,10 @@ struct ReferenceMoeGemm : public device::BaseOperator
arg.c_element_op_(v_c, v_acc);
arg.c_element_op_(v_c_up, v_acc_up);
v_c = v_c * arg.b_scale_e_n_(e, n) * arg.a_scale_t_(t);
v_c = v_c * (1.0 / (1.0 + math::exp(-v_c)));
v_c_up = v_c_up * arg.b_scale_e_n_(e, n + full_n) * arg.a_scale_t_(t);
// arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up * (1.0 / (1.0 + math::exp(-v_c_up)));
arg.c_t_k_n_(t, topk_id, n) = v_c + v_c_up;
arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up;
// arg.c_t_k_n_(t, topk_id, n) = v_c + v_c_up;
}
};