diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 8910a78e5e..c9a4c8bc5a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -278,6 +278,19 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base< using Base::NumDTensor; static constexpr auto BlockSizeNumber = Number{}; + // Clamp limit for swiglustep_and_mul: silu(g).clamp(max=L) * u.clamp(+-L), L hardcoded to 7.0 + static constexpr float kSwiGluClamp = 7.0f; + + // Helper: apply SwiGLU-step activation (silu + symmetric clamp) and return gate*up. + // Used by all four swiglustep_and_mul epilogue paths (quant/non-quant x pipeline-A/B). + __host__ __device__ static constexpr float apply_swiglustep_activation(float gate, float up) + { + tensor_operation::element_wise::Silu{}(gate, gate); + gate = math::min(gate, kSwiGluClamp); + up = math::min(math::max(up, -kSwiGluClamp), kSwiGluClamp); + return gate * up; + } + using mfma_selector = MfmaSelector; static constexpr index_t KPack = math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); @@ -1453,6 +1466,26 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base< tensor_operation::element_wise::Gelu{}(gate, gate); c_thread_buf_fp32(cidx) = gate * up; } + else if constexpr(ActivationOperation == + Activation::swiglustep_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up); + } } else { @@ -1514,6 +1547,18 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base< tensor_operation::element_wise::Gelu{}(gate, gate); c_thread_buf_fp32(cidx) = gate * up; } + else if constexpr(ActivationOperation == + Activation::swiglustep_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; + } + c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up); + } } else { @@ -1926,6 +1971,26 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base< tensor_operation::element_wise::Gelu{}(gate, gate); c_thread_buf_fp32(cidx) = gate * up; } + else if constexpr(ActivationOperation == + Activation::swiglustep_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up); + } } else { @@ -1987,6 +2052,18 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base< tensor_operation::element_wise::Gelu{}(gate, gate); c_thread_buf_fp32(cidx) = gate * up; } + else if constexpr(ActivationOperation == + Activation::swiglustep_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; + } + c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up); + } } else {