mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 12:30:16 +00:00
[CK] add swiglustep_and_mul activation to gridwise_moe_gemm (#6873)
Title: feat(composablekernel): add swiglustep_and_mul activation to gridwise_moe_gemm

Description:

## Motivation

Step-3.5-Flash uses a clamped SwiGLU activation (`swiglu_limits[43]=7`, `swiglu_limits[44]=7`) for layers 43 and 44. Without this kernel path, those layers produce BOS token spam because unclamped gate/up values accumulate floating-point noise over 200+ decode steps, degrading output quality (cosine similarity drops from 0.999989 to ~0.998982).

## Changes

Add `swiglustep_and_mul` as a new `Activation` enum branch in `gridwise_moe_gemm.hpp`, covering all 4 code paths:

- Quantized (A×B scale) + IsInputGemm=true
- Quantized (A×B scale) + IsInputGemm=false
- Non-quantized + IsInputGemm=true
- Non-quantized + IsInputGemm=false

The activation computes:

    gate = silu(gate)
    gate = clamp(gate, max=7.0f)
    up = clamp(up, min=-7.0f, max=7.0f)
    output = gate * up

Also handles the `MulRoutedWeight` case (topk weight multiplication) and `pk_i4_t` weight scaling (×16 dequant factor).

## Verification

- Tested on gfx950 (MI350X, 8×GPU)
- cosine similarity for layers 43/44: **0.999989** (vs 0.998982 before fix)
- End-to-end Step-3.5-Flash inference: no BOS spam, output coherent
- BF16 tp=2/tp=4 and FP8 tp=2/tp=4 all verified PASS

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
@@ -278,6 +278,19 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
|
||||
using Base::NumDTensor;
|
||||
static constexpr auto BlockSizeNumber = Number<BlockSize>{};
|
||||
|
||||
// Clamp limit L for the swiglustep_and_mul activation:
//   out = min(silu(gate), L) * clamp(up, -L, +L)
// L is a compile-time constant fixed at 7.0; it is not configurable per call.
|
||||
static constexpr float kSwiGluClamp = 7.0f;
|
||||
|
||||
// Helper: apply SwiGLU-step activation (silu + symmetric clamp) and return gate*up.
|
||||
// Used by all four swiglustep_and_mul epilogue paths (quant/non-quant x pipeline-A/B).
|
||||
__host__ __device__ static constexpr float apply_swiglustep_activation(float gate, float up)
|
||||
{
|
||||
tensor_operation::element_wise::Silu{}(gate, gate);
|
||||
gate = math::min(gate, kSwiGluClamp);
|
||||
up = math::min(math::max(up, -kSwiGluClamp), kSwiGluClamp);
|
||||
return gate * up;
|
||||
}
|
||||
|
||||
using mfma_selector = MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>;
|
||||
static constexpr index_t KPack =
|
||||
math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk);
|
||||
@@ -1453,6 +1466,26 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
|
||||
tensor_operation::element_wise::Gelu{}(gate, gate);
|
||||
c_thread_buf_fp32(cidx) = gate * up;
|
||||
}
|
||||
else if constexpr(ActivationOperation ==
|
||||
Activation::swiglustep_and_mul)
|
||||
{
|
||||
const float scale_up =
|
||||
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
|
||||
PerTokenQuant];
|
||||
float gate = scale_a * scale_b * c_thread_buf[cidx];
|
||||
float up = scale_a * scale_up * c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.template AsType<float>()[m4];
|
||||
up = up * topk_weights.template AsType<float>()[m4];
|
||||
}
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
{
|
||||
gate *= 16;
|
||||
up *= 16;
|
||||
}
|
||||
c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1514,6 +1547,18 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
|
||||
tensor_operation::element_wise::Gelu{}(gate, gate);
|
||||
c_thread_buf_fp32(cidx) = gate * up;
|
||||
}
|
||||
else if constexpr(ActivationOperation ==
|
||||
Activation::swiglustep_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
float up = c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.template AsType<float>()[m4];
|
||||
up = up * topk_weights.template AsType<float>()[m4];
|
||||
}
|
||||
c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1926,6 +1971,26 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
|
||||
tensor_operation::element_wise::Gelu{}(gate, gate);
|
||||
c_thread_buf_fp32(cidx) = gate * up;
|
||||
}
|
||||
else if constexpr(ActivationOperation ==
|
||||
Activation::swiglustep_and_mul)
|
||||
{
|
||||
const float scale_up =
|
||||
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
|
||||
PerTokenQuant];
|
||||
float gate = scale_a * scale_b * c_thread_buf[cidx];
|
||||
float up = scale_a * scale_up * c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.template AsType<float>()[m4];
|
||||
up = up * topk_weights.template AsType<float>()[m4];
|
||||
}
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
{
|
||||
gate *= 16;
|
||||
up *= 16;
|
||||
}
|
||||
c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1987,6 +2052,18 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base<
|
||||
tensor_operation::element_wise::Gelu{}(gate, gate);
|
||||
c_thread_buf_fp32(cidx) = gate * up;
|
||||
}
|
||||
else if constexpr(ActivationOperation ==
|
||||
Activation::swiglustep_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
float up = c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.template AsType<float>()[m4];
|
||||
up = up * topk_weights.template AsType<float>()[m4];
|
||||
}
|
||||
c_thread_buf_fp32(cidx) = apply_swiglustep_activation(gate, up);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user