mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
Add swiglustep_and_mul branches to gridwise_moe_gemm (4 paths, hardcoded 7.0f clamp)
This commit is contained in:
@@ -1450,6 +1450,29 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
|
||||
tensor_operation::element_wise::Gelu{}(gate, gate);
|
||||
c_thread_buf_fp32(cidx) = gate * up;
|
||||
}
|
||||
else if constexpr(ActivationOperation ==
|
||||
Activation::swiglustep_and_mul)
|
||||
{
|
||||
const float scale_up =
|
||||
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
|
||||
PerTokenQuant];
|
||||
float gate = scale_a * scale_b * c_thread_buf[cidx];
|
||||
float up = scale_a * scale_up * c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.template AsType<float>()[m4];
|
||||
up = up * topk_weights.template AsType<float>()[m4];
|
||||
}
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
{
|
||||
gate *= 16;
|
||||
up *= 16;
|
||||
}
|
||||
tensor_operation::element_wise::Silu{}(gate, gate);
|
||||
gate = gate < 7.0f ? gate : 7.0f;
|
||||
up = up < 7.0f ? (up > -7.0f ? up : -7.0f) : 7.0f;
|
||||
c_thread_buf_fp32(cidx) = gate * up;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1511,6 +1534,21 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
|
||||
tensor_operation::element_wise::Gelu{}(gate, gate);
|
||||
c_thread_buf_fp32(cidx) = gate * up;
|
||||
}
|
||||
else if constexpr(ActivationOperation ==
|
||||
Activation::swiglustep_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
float up = c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.template AsType<float>()[m4];
|
||||
up = up * topk_weights.template AsType<float>()[m4];
|
||||
}
|
||||
tensor_operation::element_wise::Silu{}(gate, gate);
|
||||
gate = gate < 7.0f ? gate : 7.0f;
|
||||
up = up < 7.0f ? (up > -7.0f ? up : -7.0f) : 7.0f;
|
||||
c_thread_buf_fp32(cidx) = gate * up;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1923,6 +1961,29 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
|
||||
tensor_operation::element_wise::Gelu{}(gate, gate);
|
||||
c_thread_buf_fp32(cidx) = gate * up;
|
||||
}
|
||||
else if constexpr(ActivationOperation ==
|
||||
Activation::swiglustep_and_mul)
|
||||
{
|
||||
const float scale_up =
|
||||
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
|
||||
PerTokenQuant];
|
||||
float gate = scale_a * scale_b * c_thread_buf[cidx];
|
||||
float up = scale_a * scale_up * c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.template AsType<float>()[m4];
|
||||
up = up * topk_weights.template AsType<float>()[m4];
|
||||
}
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
{
|
||||
gate *= 16;
|
||||
up *= 16;
|
||||
}
|
||||
tensor_operation::element_wise::Silu{}(gate, gate);
|
||||
gate = gate < 7.0f ? gate : 7.0f;
|
||||
up = up < 7.0f ? (up > -7.0f ? up : -7.0f) : 7.0f;
|
||||
c_thread_buf_fp32(cidx) = gate * up;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1984,6 +2045,21 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
|
||||
tensor_operation::element_wise::Gelu{}(gate, gate);
|
||||
c_thread_buf_fp32(cidx) = gate * up;
|
||||
}
|
||||
else if constexpr(ActivationOperation ==
|
||||
Activation::swiglustep_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
float up = c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.template AsType<float>()[m4];
|
||||
up = up * topk_weights.template AsType<float>()[m4];
|
||||
}
|
||||
tensor_operation::element_wise::Silu{}(gate, gate);
|
||||
gate = gate < 7.0f ? gate : 7.0f;
|
||||
up = up < 7.0f ? (up > -7.0f ? up : -7.0f) : 7.0f;
|
||||
c_thread_buf_fp32(cidx) = gate * up;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user