From 03533374477ff5564ef78d01ac1b5f3ba54aa38f Mon Sep 17 00:00:00 2001 From: lalala-sh Date: Fri, 11 Apr 2025 02:18:16 +0000 Subject: [PATCH] refine activation code & complete moe example --- .../moe_gemm1_xdl_fp8.cpp | 2 +- .../moe_gemm1_xdl_pk_i4.cpp | 10 +-- .../gpu/grid/gridwise_moe_gemm.hpp | 77 ++++++++----------- .../cpu/reference_moe_gemm.hpp | 46 +++++------ 4 files changed, 59 insertions(+), 76 deletions(-) diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp index 4cc20377dd..f428bb22f3 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp @@ -131,7 +131,7 @@ static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); static constexpr ck::index_t EVec = 16 / sizeof(EDataType); static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = 1; -static constexpr ck::index_t ActOP = 2; // 0: gelu_and_mul, 2: silu_and_mul +static constexpr ck::index_t ActOP = 1; // 0: gelu_and_mul, 1: silu_and_mul using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp index f46aa45d9b..e83fc22988 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp @@ -62,7 +62,7 @@ struct MulABScale (void)d0; (void)d1; #if CK_USE_PK4_LAYOUT_SHUFFLE - e = ck::type_convert(c * 16); + e = ck::type_convert(c); #else e = ck::type_convert(c); #endif @@ -74,7 +74,7 @@ struct MulABScale (void)d0; (void)d1; #if CK_USE_PK4_LAYOUT_SHUFFLE - e = ck::type_convert(c * 16); + e = ck::type_convert(c); #else e = ck::type_convert(c); #endif @@ -125,7 +125,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio static constexpr ck::index_t MPerBlock = 128; static constexpr ck::index_t Nswizzle = false; -static constexpr ck::index_t ActOP = 2; // 0: gelu_and_mul, 2: silu_and_mul +static constexpr ck::index_t Act_OP = 1; // 0: gelu_and_mul, 1: silu_and_mul // clang-format off using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< Row, Col, DsLayout, ELayout, @@ -203,7 +203,7 @@ int main(int argc, char* argv[]) { expert_ids.mData[i] = eids[i]; } - int token_per_tile = tokens * topk / valid_tile_num; + int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; int tokenid = 0; for(int i = 0; i < sorted_size; i++) { @@ -479,7 +479,7 @@ int main(int argc, char* argv[]) e_device_buf.FromDevice(e_t_n_device_result.mData.data()); return ck::utils::check_err( - e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1) ? 0 : 1; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 7582669e08..beb6a76392 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -29,9 +29,8 @@ namespace ck { enum Activation { - gelu = 0, - silu = 1, - swiglu = 2 + gelu_and_mul = 0, + silu_and_mul = 1 }; template scale_token_ids; vector_type topk_weights; // for gemm2 only static_for<0, NXdlPerWave, 1>{}([&](auto n0) { @@ -1462,27 +1463,7 @@ struct GridwiseMoeGemm constexpr auto cidx = Number{}; if constexpr(IsInputGemm) // gu fusion { - if constexpr(ActivationOperation == Activation::silu) - { - tensor_operation::element_wise::Silu{}(c_thread_buf(cidx), - c_thread_buf(cidx)); - } - else if(ActivationOperation == Activation::gelu) - { - const float scale_up = - p_scale_b[(n0 * NWave * NPerXdl + problem.N) * - PerTokenQuant]; - auto gate = scale_a * scale_b * c_thread_buf[cidx]; - auto up = scale_a * scale_up * c_thread_buf_up[cidx]; - if constexpr(is_same_v, pk_i4_t>) - { - gate *= 16; - up *= 16; - } - tensor_operation::element_wise::Gelu{}(gate, gate); - c_thread_buf(cidx) = gate * up; - } - else if(ActivationOperation == Activation::swiglu) + if constexpr(ActivationOperation == Activation::silu_and_mul) { const float scale_up = p_scale_b[(n0 * NWave * NPerXdl + problem.N) * @@ -1497,6 +1478,21 @@ struct GridwiseMoeGemm tensor_operation::element_wise::Silu{}(gate, gate); c_thread_buf(cidx) = gate * up; } + else if(ActivationOperation == Activation::gelu_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + auto gate = scale_a * scale_b * c_thread_buf[cidx]; + auto up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf(cidx) = gate * up; + } } else { @@ -1511,10 +1507,6 @@ struct GridwiseMoeGemm } else { - static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); - static_assert(M4 == 4); - const index_t m1 = get_warp_local_1d_id() / NWave; - const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; vector_type topk_weights; // for gemm2 only static_for<0, NXdlPerWave, 1>{}([&](auto n0) { static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave @@ -1533,25 +1525,20 @@ struct GridwiseMoeGemm constexpr auto cidx = Number{}; if constexpr(IsInputGemm) // gu fusion { - if constexpr(ActivationOperation == Activation::silu) - { - tensor_operation::element_wise::Silu{}(c_thread_buf(cidx), - c_thread_buf(cidx)); - } - else if(ActivationOperation == Activation::gelu) - { - auto gate = c_thread_buf[cidx]; - auto up = c_thread_buf_up[cidx]; - tensor_operation::element_wise::Gelu{}(gate, gate); - c_thread_buf(cidx) = gate * up; - } - else if(ActivationOperation == Activation::swiglu) + if constexpr(ActivationOperation == Activation::silu_and_mul) { auto gate = c_thread_buf[cidx]; auto up = c_thread_buf_up[cidx]; tensor_operation::element_wise::Silu{}(gate, gate); c_thread_buf(cidx) = gate * up; } + else if(ActivationOperation == Activation::gelu_and_mul) + { + auto gate = c_thread_buf[cidx]; + auto up = c_thread_buf_up[cidx]; + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf(cidx) = gate * up; + } } else { diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp index a50307e86f..9260e812bb 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp @@ -80,10 +80,7 @@ struct ReferenceMoeGemm : public device::BaseOperator float Run(const Argument& arg) { - if constexpr(ActivationType > 2) - { - static_assert(false, "Not supported activation type"); - } + static_assert(ActivationType < 2, "Not supported activation type"); const int full_n = arg.c_t_k_n_.mDesc.GetLengths()[2]; auto f_mk_kn_mn = [&](auto m, auto n) { const int K = arg.a_t_k_.mDesc.GetLengths()[1]; @@ -148,44 +145,43 @@ struct ReferenceMoeGemm : public device::BaseOperator else { arg.b_element_op_(v_b, arg.b_e_n_k_(e, k, n)); - if constexpr(ActivationType == 2) - { - arg.b_element_op_(v_b_up, arg.b_e_n_k_(e, k, n + full_n)); - } + arg.b_element_op_(v_b_up, arg.b_e_n_k_(e, k, n + full_n)); } v_acc += ck::type_convert(v_a) * ck::type_convert(v_b); - - if constexpr(ActivationType == 2) - { - v_acc_up += ck::type_convert(v_a) * - ck::type_convert(v_b_up); - } + v_acc_up += ck::type_convert(v_a) * + ck::type_convert(v_b_up); } CDataType v_c{0}; CDataType v_c_up{0}; arg.c_element_op_(v_c, v_acc); - if constexpr(ActivationType == 2) + if constexpr(ActivationType == 1) { arg.c_element_op_(v_c_up, v_acc_up); - v_c = v_c * arg.b_scale_e_n_(e, n) * arg.a_scale_t_(t); - v_c = v_c * (1.0 / (1.0 + math::exp(-v_c))); + v_c = v_c * arg.b_scale_e_n_(e, n) * arg.a_scale_t_(t); + if constexpr(is_same_v) + { + v_c_up *= 16; + v_c *= 16; + } + tensor_operation::element_wise::Silu{}(v_c, v_c); v_c_up = v_c_up * arg.b_scale_e_n_(e, n + full_n) * arg.a_scale_t_(t); arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up; } - else + else if constexpr(ActivationType == 0) { - if constexpr(ActivationType == 1) + arg.c_element_op_(v_c_up, v_acc_up); + v_c = v_c * arg.b_scale_e_n_(e, n) * arg.a_scale_t_(t); + if constexpr(is_same_v) { - tensor_operation::element_wise::Silu{}(v_c, v_c); + v_c_up *= 16; + v_c *= 16; } - else if constexpr(ActivationType == 0) - { - tensor_operation::element_wise::Gelu{}(v_c, v_c); - } - arg.c_t_k_n_(t, topk_id, n) = v_c; + tensor_operation::element_wise::Gelu{}(v_c, v_c); + v_c_up = v_c_up * arg.b_scale_e_n_(e, n + full_n) * arg.a_scale_t_(t); + arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up; } } };