[CK] add swiglustep_and_mul activation to gridwise_moe_gemm (#6873)

Title:
feat(composablekernel): add swiglustep_and_mul activation to
gridwise_moe_gemm

  Description:
  ## Motivation

Step-3.5-Flash uses a clamped SwiGLU activation (`swiglu_limits[43]=7`,
  `swiglu_limits[44]=7`) for layers 43 and 44. Without this kernel path,
  those layers produce BOS token spam because unclamped gate/up values
  accumulate floating-point noise over 200+ decode steps, degrading
  output quality (cosine similarity drops from 0.999989 to ~0.998982).

  ## Changes

  Add `swiglustep_and_mul` as a new `Activation` enum branch in
  `gridwise_moe_gemm.hpp`, covering all 4 code paths:
  - Quantized (A×B scale) + IsInputGemm=true
  - Quantized (A×B scale) + IsInputGemm=false
  - Non-quantized + IsInputGemm=true
  - Non-quantized + IsInputGemm=false

  The activation computes:
  gate = silu(gate)
  gate = clamp(gate, max=7.0f)
  up   = clamp(up,   min=-7.0f, max=7.0f)
  output = gate * up

Also handles the `MulRoutedWeight` case (topk weight multiplication) and
  `pk_i4_t` weight scaling (×16 dequant factor).

  ## Verification

  - Tested on gfx950 (MI350X, 8×GPU)
- cosine similarity for layers 43/44: **0.999989** (vs 0.998982 before
fix)
  - End-to-end Step-3.5-Flash inference: no BOS spam, output coherent
  - BF16 tp=2/tp=4 and FP8 tp=2/tp=4 all verified PASS
- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Linjun-AMD
2026-05-07 13:59:47 +08:00
committed by GitHub
parent 1cf336d87a
commit f17537b7c2

View File

@@ -278,6 +278,19 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
using Base::NumDTensor;
static constexpr auto BlockSizeNumber = Number<BlockSize>{};
// Clamp limit for swiglustep_and_mul: silu(g).clamp(max=L) * u.clamp(+-L), L hardcoded to 7.0
static constexpr float kSwiGluClamp = 7.0f;
// Helper: apply SwiGLU-step activation (silu + symmetric clamp) and return gate*up.
// Used by all four swiglustep_and_mul epilogue paths (quant/non-quant x pipeline-A/B).
__host__ __device__ static constexpr float apply_swiglustep_activation(float gate, float up)
{
tensor_operation::element_wise::Silu{}(gate, gate);
gate = math::min(gate, kSwiGluClamp);
up = math::min(math::max(up, -kSwiGluClamp), kSwiGluClamp);
return gate * up;
}
using mfma_selector = MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>;
static constexpr index_t KPack =
math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk);
@@ -1453,6 +1466,26 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
else if constexpr(ActivationOperation ==
Activation::swiglustep_and_mul)
{
const float scale_up =
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
PerTokenQuant];
float gate = scale_a * scale_b * c_thread_buf[cidx];
float up = scale_a * scale_up * c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.template AsType<float>()[m4];
up = up * topk_weights.template AsType<float>()[m4];
}
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
gate *= 16;
up *= 16;
}
c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up);
}
}
else
{
@@ -1514,6 +1547,18 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
else if constexpr(ActivationOperation ==
Activation::swiglustep_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.template AsType<float>()[m4];
up = up * topk_weights.template AsType<float>()[m4];
}
c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up);
}
}
else
{
@@ -1926,6 +1971,26 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
else if constexpr(ActivationOperation ==
Activation::swiglustep_and_mul)
{
const float scale_up =
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
PerTokenQuant];
float gate = scale_a * scale_b * c_thread_buf[cidx];
float up = scale_a * scale_up * c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.template AsType<float>()[m4];
up = up * topk_weights.template AsType<float>()[m4];
}
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
gate *= 16;
up *= 16;
}
c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up);
}
}
else
{
@@ -1987,6 +2052,18 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
else if constexpr(ActivationOperation ==
Activation::swiglustep_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.template AsType<float>()[m4];
up = up * topk_weights.template AsType<float>()[m4];
}
c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up);
}
}
else
{