diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp index 6e047dd64a..2f9a9cd21b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp @@ -28,8 +28,9 @@ namespace ck { enum Activation { - gelu_and_mul = 0, - silu_and_mul = 1 + gelu_and_mul = 0, + silu_and_mul = 1, + swiglustep_and_mul = 2 }; template , pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Silu{}(gate, gate); + gate = gate < 7.0f ? gate : 7.0f; + up = up < 7.0f ? (up > -7.0f ? up : -7.0f) : 7.0f; + c_thread_buf(cidx) = gate * up; + } else if(ActivationOperation == Activation::gelu_and_mul) { float gate = c_thread_buf[cidx]; @@ -2118,6 +2137,25 @@ struct GridwiseMoeGemmBlockScale tensor_operation::element_wise::Silu{}(gate, gate); c_thread_buf(cidx) = gate * up; } + else if constexpr(ActivationOperation == Activation::swiglustep_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weight; + up = up * topk_weight; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Silu{}(gate, gate); + gate = gate < 7.0f ? gate : 7.0f; + up = up < 7.0f ? (up > -7.0f ? up : -7.0f) : 7.0f; + c_thread_buf(cidx) = gate * up; + } else if(ActivationOperation == Activation::gelu_and_mul) { float gate = c_thread_buf[cidx];