From f17537b7c2051ed8374299da30b20bf8b72981f2 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Thu, 7 May 2026 13:59:47 +0800 Subject: [PATCH] [CK] add swiglustep_and_mul activation to gridwise_moe_gemm (#6873) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Title: feat(composablekernel): add swiglustep_and_mul activation to gridwise_moe_gemm Description: ## Motivation Step-3.5-Flash uses a clamped SwiGLU activation (`swiglu_limits[43]=7`, `swiglu_limits[44]=7`) for layers 43 and 44. Without this kernel path, those layers produce BOS token spam because unclamped gate/up values accumulate floating-point noise over 200+ decode steps, degrading output quality (cosine similarity drops from 0.999989 to ~0.998982). ## Changes Add `swiglustep_and_mul` as a new `Activation` enum branch in `gridwise_moe_gemm.hpp`, covering all 4 code paths: - Quantized (A×B scale) + IsInputGemm=true - Quantized (A×B scale) + IsInputGemm=false - Non-quantized + IsInputGemm=true - Non-quantized + IsInputGemm=false The activation computes: gate = silu(gate) gate = clamp(gate, max=7.0f) up = clamp(up, min=-7.0f, max=7.0f) output = gate * up Also handles the `MulRoutedWeight` case (topk weight multiplication) and `pk_i4_t` weight scaling (×16 dequant factor). ## Verification - Tested on gfx950 (MI350X, 8×GPU) - cosine similarity for layers 43/44: **0.999989** (vs 0.998982 before fix) - End-to-end Step-3.5-Flash inference: no BOS spam, output coherent - BF16 tp=2/tp=4 and FP8 tp=2/tp=4 all verified PASS - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. 
--- .../gpu/grid/gridwise_moe_gemm.hpp | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 8910a78e5e..c9a4c8bc5a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -278,6 +278,19 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base< using Base::NumDTensor; static constexpr auto BlockSizeNumber = Number{}; + // Clamp limit L for swiglustep_and_mul: silu(g).clamp(max=L) * u.clamp(min=-L, max=L); L is hardcoded to 7.0 + static constexpr float kSwiGluClamp = 7.0f; + + // Helper: gate <- silu(gate) clamped from above only (max=kSwiGluClamp); up clamped to [-kSwiGluClamp, kSwiGluClamp]; returns gate*up. + // Called by the four swiglustep_and_mul epilogue branches after they apply any scales, routed weights and the pk_i4_t x16 factor. + __host__ __device__ static constexpr float apply_swiglustep_activation(float gate, float up) + { + tensor_operation::element_wise::Silu{}(gate, gate); + gate = math::min(gate, kSwiGluClamp); + up = math::min(math::max(up, -kSwiGluClamp), kSwiGluClamp); + return gate * up; + } + using mfma_selector = MfmaSelector; static constexpr index_t KPack = math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); @@ -1453,6 +1466,26 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base< tensor_operation::element_wise::Gelu{}(gate, gate); c_thread_buf_fp32(cidx) = gate * up; } + else if constexpr(ActivationOperation == + Activation::swiglustep_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + 
c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up); + } } else { @@ -1514,6 +1547,18 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base< tensor_operation::element_wise::Gelu{}(gate, gate); c_thread_buf_fp32(cidx) = gate * up; } + else if constexpr(ActivationOperation == + Activation::swiglustep_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; + } + c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up); + } } else { @@ -1926,6 +1971,26 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base< tensor_operation::element_wise::Gelu{}(gate, gate); c_thread_buf_fp32(cidx) = gate * up; } + else if constexpr(ActivationOperation == + Activation::swiglustep_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up); + } } else { @@ -1987,6 +2052,18 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base< tensor_operation::element_wise::Gelu{}(gate, gate); c_thread_buf_fp32(cidx) = gate * up; } + else if constexpr(ActivationOperation == + Activation::swiglustep_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; + } + c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up); + } } else {