Add swiglustep_and_mul branches to gridwise_moe_gemm (4 paths, hardcoded 7.0f clamp)

This commit is contained in:
Jun Lin
2026-04-23 21:16:50 +00:00
parent fdf4bb7fcc
commit defd7ad297

View File

@@ -1450,6 +1450,29 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
else if constexpr(ActivationOperation == Activation::swiglustep_and_mul)
{
    // swiglustep_and_mul on scaled accumulators:
    //   out = min(silu(gate), limit) * clamp(up, -limit, +limit)
    // The bound is a fixed constant in this branch.
    constexpr float swiglustep_limit = 7.0f;

    // Scale for the "up" half: same p_scale_b table, indexed with an extra
    // problem.N added to the column offset.
    const float up_scale = p_scale_b[(n0 * NWave * NPerXdl + problem.N) * PerTokenQuant];

    float gate_act = scale_a * scale_b * c_thread_buf[cidx];
    float up_lin   = scale_a * up_scale * c_thread_buf_up[cidx];

    if constexpr(MulRoutedWeight)
    {
        // Both halves are multiplied by the same top-k weight before the activation.
        const float route_w = topk_weights.template AsType<float>()[m4];
        gate_act            = gate_act * route_w;
        up_lin              = up_lin * route_w;
    }

    if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
    {
        // NOTE(review): fixed x16 factor for pk_i4_t B data, matching the sibling
        // activation branches -- presumably tied to the packed-int4 encoding; confirm.
        gate_act *= 16;
        up_lin *= 16;
    }

    // SiLU is applied in place to the gate half only.
    tensor_operation::element_wise::Silu{}(gate_act, gate_act);

    // Upper-bound the activated gate. Any value for which `< limit` is false
    // ends up at the limit.
    if(!(gate_act < swiglustep_limit))
    {
        gate_act = swiglustep_limit;
    }

    // Two-sided clamp of the up half; the upper bound is tested first.
    if(!(up_lin < swiglustep_limit))
    {
        up_lin = swiglustep_limit;
    }
    else if(!(up_lin > -swiglustep_limit))
    {
        up_lin = -swiglustep_limit;
    }

    c_thread_buf_fp32(cidx) = gate_act * up_lin;
}
}
else
{
@@ -1511,6 +1534,21 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
else if constexpr(ActivationOperation == Activation::swiglustep_and_mul)
{
    // swiglustep_and_mul on the raw accumulators (no scale factors in this path):
    //   out = min(silu(gate), limit) * clamp(up, -limit, +limit)
    // The bound is a fixed constant in this branch.
    constexpr float swiglustep_limit = 7.0f;

    float gate_act = c_thread_buf[cidx];
    float up_lin   = c_thread_buf_up[cidx];

    if constexpr(MulRoutedWeight)
    {
        // Both halves are multiplied by the same top-k weight before the activation.
        const float route_w = topk_weights.template AsType<float>()[m4];
        gate_act            = gate_act * route_w;
        up_lin              = up_lin * route_w;
    }

    // SiLU is applied in place to the gate half only.
    tensor_operation::element_wise::Silu{}(gate_act, gate_act);

    // Upper-bound the activated gate. Any value for which `< limit` is false
    // ends up at the limit.
    if(!(gate_act < swiglustep_limit))
    {
        gate_act = swiglustep_limit;
    }

    // Two-sided clamp of the up half; the upper bound is tested first.
    if(!(up_lin < swiglustep_limit))
    {
        up_lin = swiglustep_limit;
    }
    else if(!(up_lin > -swiglustep_limit))
    {
        up_lin = -swiglustep_limit;
    }

    c_thread_buf_fp32(cidx) = gate_act * up_lin;
}
}
else
{
@@ -1923,6 +1961,29 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
else if constexpr(ActivationOperation == Activation::swiglustep_and_mul)
{
    // swiglustep_and_mul on scaled accumulators:
    //   out = min(silu(gate), limit) * clamp(up, -limit, +limit)
    // The bound is a fixed constant in this branch.
    constexpr float swiglustep_limit = 7.0f;

    // Scale for the "up" half: same p_scale_b table, indexed with an extra
    // problem.N added to the column offset.
    const float up_scale = p_scale_b[(n0 * NWave * NPerXdl + problem.N) * PerTokenQuant];

    float gate_act = scale_a * scale_b * c_thread_buf[cidx];
    float up_lin   = scale_a * up_scale * c_thread_buf_up[cidx];

    if constexpr(MulRoutedWeight)
    {
        // Both halves are multiplied by the same top-k weight before the activation.
        const float route_w = topk_weights.template AsType<float>()[m4];
        gate_act            = gate_act * route_w;
        up_lin              = up_lin * route_w;
    }

    if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
    {
        // NOTE(review): fixed x16 factor for pk_i4_t B data, matching the sibling
        // activation branches -- presumably tied to the packed-int4 encoding; confirm.
        gate_act *= 16;
        up_lin *= 16;
    }

    // SiLU is applied in place to the gate half only.
    tensor_operation::element_wise::Silu{}(gate_act, gate_act);

    // Upper-bound the activated gate. Any value for which `< limit` is false
    // ends up at the limit.
    if(!(gate_act < swiglustep_limit))
    {
        gate_act = swiglustep_limit;
    }

    // Two-sided clamp of the up half; the upper bound is tested first.
    if(!(up_lin < swiglustep_limit))
    {
        up_lin = swiglustep_limit;
    }
    else if(!(up_lin > -swiglustep_limit))
    {
        up_lin = -swiglustep_limit;
    }

    c_thread_buf_fp32(cidx) = gate_act * up_lin;
}
}
else
{
@@ -1984,6 +2045,21 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
else if constexpr(ActivationOperation == Activation::swiglustep_and_mul)
{
    // swiglustep_and_mul on the raw accumulators (no scale factors in this path):
    //   out = min(silu(gate), limit) * clamp(up, -limit, +limit)
    // The bound is a fixed constant in this branch.
    constexpr float swiglustep_limit = 7.0f;

    float gate_act = c_thread_buf[cidx];
    float up_lin   = c_thread_buf_up[cidx];

    if constexpr(MulRoutedWeight)
    {
        // Both halves are multiplied by the same top-k weight before the activation.
        const float route_w = topk_weights.template AsType<float>()[m4];
        gate_act            = gate_act * route_w;
        up_lin              = up_lin * route_w;
    }

    // SiLU is applied in place to the gate half only.
    tensor_operation::element_wise::Silu{}(gate_act, gate_act);

    // Upper-bound the activated gate. Any value for which `< limit` is false
    // ends up at the limit.
    if(!(gate_act < swiglustep_limit))
    {
        gate_act = swiglustep_limit;
    }

    // Two-sided clamp of the up half; the upper bound is tested first.
    if(!(up_lin < swiglustep_limit))
    {
        up_lin = swiglustep_limit;
    }
    else if(!(up_lin > -swiglustep_limit))
    {
        up_lin = -swiglustep_limit;
    }

    c_thread_buf_fp32(cidx) = gate_act * up_lin;
}
}
else
{