From c9e8acc56a60dbed2e9c546198c07d1f153ac3b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=87=91=E9=BB=84=E8=89=B2=E8=91=A1=E8=90=84=E7=90=83?= =?UTF-8?q?=E5=90=9B=E5=90=9B?= Date: Tue, 21 Apr 2026 15:24:48 +0800 Subject: [PATCH] projects/composablekernel: add SwigluStep support for MoE blockscale (#6118) ## Summary - add `swiglustep_and_mul` to the composablekernel MoE blockscale activation enum - implement the corresponding blockscale epilogue path for `SwigluStep` - keep existing `silu` and `gelu` paths unchanged ## Scope This PR covers the classic composablekernel blockscale MoE path under `projects/composablekernel`. This is separate from the `ck_tile` / FlatMM path being discussed in ROCm/rocm-libraries#5992. ## Motivation `Step-3.5-Flash-FP8` uses `SwigluStep` in its MoE MLP path. The dependent AITER change needs native support for this activation in the classic composablekernel MoE blockscale path. ## Validation - patch is limited to two composablekernel files under `projects/composablekernel` - existing `silu` / `gelu` paths are unchanged - dependent AITER runtime validation hit the classic CK 2-stage path with AITER MoE enabled --- .../gridwise_gemm_xdl_cshuffle_common.hpp | 5 ++- .../gpu/grid/gridwise_moe_gemm_blockscale.hpp | 38 +++++++++++++++++++ 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp index 6e047dd64a..2f9a9cd21b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp @@ -28,8 +28,9 @@ namespace ck { enum Activation { - gelu_and_mul = 0, - silu_and_mul = 1 + gelu_and_mul = 0, + silu_and_mul = 1, + swiglustep_and_mul = 2 }; template , pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Silu{}(gate, gate); + gate = gate < 7.0f ? gate : 7.0f; up = up < 7.0f ? (up > -7.0f ? up : -7.0f) : 7.0f; c_thread_buf(cidx) = gate * up; } else if(ActivationOperation == Activation::gelu_and_mul) { float gate = c_thread_buf[cidx]; @@ -2118,6 +2137,25 @@ struct GridwiseMoeGemmBlockScale tensor_operation::element_wise::Silu{}(gate, gate); c_thread_buf(cidx) = gate * up; } + else if constexpr(ActivationOperation == Activation::swiglustep_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weight; + up = up * topk_weight; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Silu{}(gate, gate); + gate = gate < 7.0f ? gate : 7.0f; + up = up < 7.0f ? (up > -7.0f ? up : -7.0f) : 7.0f; + c_thread_buf(cidx) = gate * up; + } else if(ActivationOperation == Activation::gelu_and_mul) { float gate = c_thread_buf[cidx];