From 319e6a7a659668d4de612bd91f01ede7aaed69a0 Mon Sep 17 00:00:00 2001 From: valarLip <103567126+valarLip@users.noreply.github.com> Date: Thu, 13 Feb 2025 15:34:34 +0800 Subject: [PATCH] porting fmoe_sorting from moe_sorting (#1884) * porting fmoe_sorting from moe_sorting * pass default example test * remod [ROCm/composable_kernel commit: 0e5e29c4e2d3d012156982e791cbe925d5dca8fa] --- example/ck_tile/15_fused_moe/fused_moe.hpp | 19 +-- .../ck_tile/15_fused_moe/fused_moesorting.hpp | 3 +- .../15_fused_moe/instances/fused_moe_api.cpp | 3 +- .../instances/fused_moesorting_api.cpp | 108 ++++++++++-------- example/ck_tile/15_fused_moe/main.cpp | 60 ++++++---- 5 files changed, 108 insertions(+), 85 deletions(-) diff --git a/example/ck_tile/15_fused_moe/fused_moe.hpp b/example/ck_tile/15_fused_moe/fused_moe.hpp index 9c4e7b09ca..1f2246fa4a 100644 --- a/example/ck_tile/15_fused_moe/fused_moe.hpp +++ b/example/ck_tile/15_fused_moe/fused_moe.hpp @@ -8,14 +8,15 @@ struct fused_moe_args { - const void* a_ptr; // [m, k], input token - const void* a_scale_ptr; // [m, 1], token scale - const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) - const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w]) - const void* g_scale_ptr; // [e, 1, n], gate(up) scale - const void* d_scale_ptr; // [e, 1, k], down scale - const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input - void* o_ptr; // [m, k], output token (no need to do zeroing) + const void* a_ptr; // [m, k], input token + const void* a_scale_ptr; // [m, 1], token scale + const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) + const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w]) + const void* g_scale_ptr; // [e, 1, n], gate(up) scale + const void* d_scale_ptr; // [e, 1, k], down scale + const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input + const void* local_expert_mask_ptr; // [e], local_expert_mask_ptr for EP + void* o_ptr; // [m, k], output token (no need to do zeroing) const void* topk_ids_ptr; // [tokens, topk] const void* topk_weight_ptr; // [tokens, topk] @@ -48,6 +49,8 @@ struct fused_moe_traits int activation; // 0:gelu, 1:silu int gate_only; // 0:g1u0, 1:g1u1 int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant + + bool local_expert_masking; // if mask experts as local expert }; float fused_moe(fused_moe_traits, fused_moe_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/15_fused_moe/fused_moesorting.hpp b/example/ck_tile/15_fused_moe/fused_moesorting.hpp index 57dace9b41..a3ff8c5bf7 100644 --- a/example/ck_tile/15_fused_moe/fused_moesorting.hpp +++ b/example/ck_tile/15_fused_moe/fused_moesorting.hpp @@ -10,7 +10,8 @@ struct fused_moesorting_trait { std::string index_type; - std::string weight_type; // currently always float + std::string weight_type; // currently always float + bool local_expert_masking; // if mask experts as local expert }; struct fused_moesorting_args : public ck_tile::MoeSortingHostArgs diff --git a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp index d29e4fd4fd..cf9ff2edba 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp @@ -17,10 +17,11 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf return 1; }(); - auto t0 = fused_moesorting_trait{"int32", "fp32"}; + auto t0 = fused_moesorting_trait{"int32", "fp32", t.local_expert_masking}; auto a0 = fused_moesorting_args{ a.topk_ids_ptr, // const void* p_topk_ids; a.topk_weight_ptr, // const void* p_weights; + a.local_expert_mask_ptr, // const void* p_local_expert_mask; a.sorted_token_ids_ptr, // void* p_sorted_token_ids; a.sorted_weight_ptr, // void* p_sorted_weights; a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids; diff --git a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp index 805cd54878..7aedaa9317 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp @@ -24,20 +24,63 @@ return ave_time; #else -#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_) \ - constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \ - constexpr bool sub_token_onshot = sub_token_onshot_; \ - using ms_problem = \ - ck_tile::MoeSortingProblemEx; \ - using kernel = ck_tile::MoeSortingKernel; \ - auto kargs = kernel::MakeKargs(a); \ - const dim3 grids = kernel::GridSize(a); \ - const dim3 blocks = kernel::BlockSize(a); \ - const auto lds_bytes = kernel::GetSmemSize(a); \ - float ave_time = ck_tile::launch_kernel( \ - s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ + +#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \ + constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \ + constexpr bool sub_token_onshot = sub_token_onshot_; \ + constexpr bool local_expert_masking = local_expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemEx; \ + using kernel = ck_tile::MoeSortingKernel; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + const auto lds_bytes = kernel::GetSmemSize(a); \ + float ave_time = ck_tile::launch_kernel( \ + s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ return ave_time; +#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \ + if(row_ % 8 == 0) \ + { \ + MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \ + } \ + else if(row_ % 4 == 0) \ + { \ + MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \ + } \ + else if(row_ % 2 == 0) \ + { \ + MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \ + } + +#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \ + if(is_sub_token_onshot) \ + { \ + MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \ + } + +#define MOE_SORTING_DISPATCH_EMASK_(row_) \ + if(is_local_expert_masking) \ + { \ + MOE_SORTING_DISPATCH_SUBTO_(row_, true) \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_SUBTO_(row_, false) \ + } + #endif #if !MOE_SORTING_USE_EX_KERNEL @@ -116,45 +159,10 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til auto sub_token_ = r_ - 2; r_ = (r_ - 2) / 8; bool is_sub_token_onshot = a.tokens <= sub_token_; + bool is_local_expert_masking = t.local_expert_masking; (void)c_; - if(is_sub_token_onshot) - { - if(r_ % 8 == 0) - { - MOE_SORTING_DISPATCH_(8, true); - } - else if(r_ % 4 == 0) - { - MOE_SORTING_DISPATCH_(4, true); - } - else if(r_ % 2 == 0) - { - MOE_SORTING_DISPATCH_(2, true); - } - else - { - MOE_SORTING_DISPATCH_(1, true); - } - } - else - { - if(r_ % 8 == 0) - { - MOE_SORTING_DISPATCH_(8, false); - } - else if(r_ % 4 == 0) - { - MOE_SORTING_DISPATCH_(4, false); - } - else if(r_ % 2 == 0) - { - MOE_SORTING_DISPATCH_(2, false); - } - else - { - MOE_SORTING_DISPATCH_(1, false); - } - } + + MOE_SORTING_DISPATCH_EMASK_(r_); // MOE_SORTING_DISPATCH_ETILE(0, 0); #endif } diff --git a/example/ck_tile/15_fused_moe/main.cpp b/example/ck_tile/15_fused_moe/main.cpp index 51611a67bc..95adcd684b 100644 --- a/example/ck_tile/15_fused_moe/main.cpp +++ b/example/ck_tile/15_fused_moe/main.cpp @@ -140,28 +140,29 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t activation = arg_parser.get_int("act"); if(stride < 0) stride = hidden_size; - std::string prec_i = arg_parser.get_str("prec_i"); - std::string prec_w = arg_parser.get_str("prec_w"); - std::string prec_o = arg_parser.get_str("prec_o"); - std::string prec_st = arg_parser.get_str("prec_st"); - std::string prec_sw = arg_parser.get_str("prec_sw"); - std::string prec_sq = arg_parser.get_str("prec_sq"); - std::string prec_kw = arg_parser.get_str("prec_kw"); - prec_st = (prec_st == "auto") ? "fp32" : prec_st; - prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw; - prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq; - prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw; - int kname = arg_parser.get_int("kname"); - int do_validation = arg_parser.get_int("v"); - int warmup = arg_parser.get_int("warmup"); - int repeat = arg_parser.get_int("repeat"); - int fused_quant = arg_parser.get_int("fquant"); - int gate_only = arg_parser.get_int("gate_only"); - int api = arg_parser.get_int("api"); - int balance = arg_parser.get_int("balance"); - int tp = arg_parser.get_int("tp"); - int init = arg_parser.get_int("init"); - uint32_t seed = arg_parser.get_uint32("seed"); + std::string prec_i = arg_parser.get_str("prec_i"); + std::string prec_w = arg_parser.get_str("prec_w"); + std::string prec_o = arg_parser.get_str("prec_o"); + std::string prec_st = arg_parser.get_str("prec_st"); + std::string prec_sw = arg_parser.get_str("prec_sw"); + std::string prec_sq = arg_parser.get_str("prec_sq"); + std::string prec_kw = arg_parser.get_str("prec_kw"); + prec_st = (prec_st == "auto") ? "fp32" : prec_st; + prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw; + prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq; + prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw; + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + int fused_quant = arg_parser.get_int("fquant"); + int gate_only = arg_parser.get_int("gate_only"); + int api = arg_parser.get_int("api"); + int balance = arg_parser.get_int("balance"); + int tp = arg_parser.get_int("tp"); + int init = arg_parser.get_int("init"); + uint32_t seed = arg_parser.get_uint32("seed"); + bool local_expert_masking = false; // TODO... // w0 (Gate+Up or Gate only, N size) ck_tile::index_t shared_intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2) / tp; @@ -230,6 +231,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor sy_host({shared_intermediate_size_1}); // smooth-quant ck_tile::HostTensor topk_ids_host({tokens, topk}); // to be sort ck_tile::HostTensor topk_weight_host({tokens, topk}); // to be sort + ck_tile::HostTensor local_expert_mask_host({experts}); int max_num_tokens_padded = topk * tokens + experts * block_m - topk; ck_tile::HostTensor sorted_token_ids_host({max_num_tokens_padded}); @@ -355,6 +357,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem sg_buf(sg_host); ck_tile::DeviceMem sd_buf(sd_host); ck_tile::DeviceMem sy_buf(sy_host); + ck_tile::DeviceMem local_expert_mask_buf(local_expert_mask_host); ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem topk_ids_buf(topk_ids_host); @@ -378,7 +381,8 @@ bool run(const ck_tile::ArgParser& arg_parser) block_m, activation, gate_only, - fused_quant}; + fused_quant, + local_expert_masking}; fused_moe_args args{a_buf.GetDeviceBuffer(), fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr, @@ -387,6 +391,8 @@ bool run(const ck_tile::ArgParser& arg_parser) fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr, fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr, fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr, + local_expert_masking ? local_expert_mask_buf.GetDeviceBuffer() + : nullptr, o_buf.GetDeviceBuffer(), topk_ids_buf.GetDeviceBuffer(), topk_weight_buf.GetDeviceBuffer(), @@ -442,12 +448,14 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::reference_moe_sorting( topk_ids_host, topk_weight_host, + local_expert_mask_host, sorted_token_ids_host, sorted_weight_host, sorted_expert_ids_host, num_sorted_tiles_host.mData[0], experts, - block_m); + block_m, + local_expert_masking); if(activation == 0) { CPU_FUSED_MOE(ck_tile::element_wise::Gelu); @@ -472,12 +480,14 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::reference_moe_sorting( topk_ids_host, topk_weight_host, + local_expert_mask_host, sorted_token_ids_host, sorted_weight_host, sorted_expert_ids_host, num_sorted_tiles_host.mData[0], experts, - block_m); + block_m, + local_expert_masking); // done, preparing GPU buffer ck_tile::DeviceMem a_buf(a_host);