From 8ed234da8c92c440524846974b4bd9e5ec68eb03 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Tue, 11 Feb 2025 17:49:17 +0800 Subject: [PATCH] [CK_TILE] moe sorting ex kernel to support expert > 128 (#1840) * moe sorting ex * fix bug for race condition * fix bug and optimze large expert * fix * optimize with sub_token_oneshot * support skip empty tokens for expert sorting * update moe_sorting * tidy code [ROCm/composable_kernel commit: c0adab485020b83f324d2efdcac2c997e19443eb] --- .../ck_tile/13_moe_sorting/moe_sorting.cpp | 63 +- .../13_moe_sorting/moe_sorting_api.cpp | 82 +++ .../13_moe_sorting/moe_sorting_api.hpp | 3 +- .../13_moe_sorting/script/smoke_test.sh | 8 + example/ck_tile/15_fused_moe/README.md | 2 +- .../instances/fused_moesorting_api.cpp | 74 ++ .../host/reference/reference_moe_sorting.hpp | 26 +- include/ck_tile/ops/fused_moe.hpp | 2 +- .../fused_moe/kernel/fused_moegemm_kernel.hpp | 2 +- .../fused_moe/kernel/moe_sorting_kernel.hpp | 693 ++++++++++++++++-- .../fused_moe/kernel/moe_sorting_problem.hpp | 52 ++ .../pipeline/moe_sorting_problem.hpp | 28 - 12 files changed, 936 insertions(+), 99 deletions(-) create mode 100644 include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp delete mode 100644 include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp diff --git a/example/ck_tile/13_moe_sorting/moe_sorting.cpp b/example/ck_tile/13_moe_sorting/moe_sorting.cpp index d2c4df1058..c4faa35e33 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting.cpp @@ -26,6 +26,10 @@ auto create_args(int argc, char* argv[]) .insert("k", "4", "topk") .insert("unit", "32", "unit_size") .insert("moe_buf_size", "0", "moe_buf_size") + .insert("local_eid", + "-1", + "a list of experts enabled as local expert. e.g. 
\"0,1,4,5\"\n" + "please make sure eid is in ascending order!") .insert("seed", "-1", "seed to be used, -1 means random every time") .insert("kname", "0", "when set to 1 it will print kernel name") .insert("warmup", "5", "number of iterations before benchmark the kernel") @@ -74,6 +78,7 @@ bool test_moe_sorting(ck_tile::ArgParser args) int kname = args.get_int("kname"); int warmup = args.get_int("warmup"); int repeat = args.get_int("repeat"); + int max_output_ids = ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size); @@ -90,6 +95,30 @@ bool test_moe_sorting(ck_tile::ArgParser args) return false; } + bool local_expert_masking = args.get_str("local_eid") != "-1"; + auto local_expert_masking_host = [&]() { + if(local_expert_masking) + { + auto local_eid = args.get_int_vec("local_eid"); + // std::vector v_ {num_experts, 0}; + ck_tile::HostTensor v_{{num_experts}}; + v_.SetZero(); + for(auto eid : local_eid) + { + if(eid >= num_experts) + { + throw std::runtime_error( + "local_eid larger than number of expert, please check"); + } + v_.mData[eid] = 1; + } + return v_; + } + else + // return std::vector{}; + return ck_tile::HostTensor{{1}}; + }(); + // tokens already considered batch size ck_tile::HostTensor topk_ids_host({tokens, topk}, {topk, 1}); ck_tile::HostTensor weights_host({tokens, topk}, {topk, 1}); @@ -111,6 +140,8 @@ bool test_moe_sorting(ck_tile::ArgParser args) sorted_expert_ids_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem local_expert_masking_dev( + local_expert_masking_host.get_element_space_size_in_bytes()); topk_ids_dev.ToDevice(topk_ids_host.data()); weights_dev.ToDevice(weights_host.data()); @@ -118,11 +149,15 @@ bool test_moe_sorting(ck_tile::ArgParser args) { moe_buf_dev.ToDevice(moe_buf_host.data()); } + 
if(local_expert_masking) + local_expert_masking_dev.ToDevice(local_expert_masking_host.data()); - moe_sorting_trait trait{index_prec, weight_prec}; + moe_sorting_trait trait{index_prec, weight_prec, local_expert_masking}; moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(), weights_dev.GetDeviceBuffer(), + local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer() + : nullptr, sorted_ids_dev.GetDeviceBuffer(), sorted_weights_dev.GetDeviceBuffer(), sorted_expert_ids_dev.GetDeviceBuffer(), @@ -140,15 +175,22 @@ bool test_moe_sorting(ck_tile::ArgParser args) warmup, repeat}; auto ms = moe_sorting(trait, karg, sc); - printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ms:%f , ", + printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ", index_prec.c_str(), weight_prec.c_str(), tokens, num_experts, - topk, - ms); + topk); + + if(local_expert_masking) + { + printf("local_eid:%s, ", args.get_str("local_eid").c_str()); + } + if(ms < 0) printf("not supported\n"); + else + printf("ms:%f, ", ms); fflush(stdout); if(ms < 0) { @@ -174,12 +216,14 @@ bool test_moe_sorting(ck_tile::ArgParser args) int32_t ref_total_tokens_post_pad = 0; ck_tile::reference_moe_sorting(topk_ids_host, weights_host, + local_expert_masking_host, sorted_ids_ref, sorted_weights_ref, sorted_expert_ids_ref, ref_total_tokens_post_pad, num_experts, - unit_size); + unit_size, + local_expert_masking); rtn &= ck_tile::check_err( sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6); rtn &= ck_tile::check_err(sorted_weights_host, @@ -199,9 +243,16 @@ bool test_moe_sorting(ck_tile::ArgParser args) moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0); } rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0]; + printf("total_tokens_post_pad:%d(%d), ", + ref_total_tokens_post_pad, + sorted_id_cnt_host.mData[0]); } - printf("valid:%s\n", rtn ? "y" : "n"); + printf("valid:%s", rtn ? 
"y" : "n"); + fflush(stdout); + if(!rtn) + printf(", (%d)", seed); + printf("\n"); fflush(stdout); return rtn; } diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp index 723fb3f69f..abff24a669 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp @@ -3,6 +3,12 @@ #include "moe_sorting_api.hpp" +#ifndef MOE_SORTING_USE_EX_KERNEL +#define MOE_SORTING_USE_EX_KERNEL 1 +#endif + +#if !MOE_SORTING_USE_EX_KERNEL + #define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \ constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr ck_tile::index_t expert_tile = expert_tile_; \ @@ -17,6 +23,67 @@ s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ return ave_time; +#else + +#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \ + constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \ + constexpr bool sub_token_onshot = sub_token_onshot_; \ + constexpr bool local_expert_masking = local_expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemEx; \ + using kernel = ck_tile::MoeSortingKernel; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + const auto lds_bytes = kernel::GetSmemSize(a); \ + float ave_time = ck_tile::launch_kernel( \ + s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ + return ave_time; + +#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \ + if(row_ % 8 == 0) \ + { \ + MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \ + } \ + else if(row_ % 4 == 0) \ + { \ + MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \ + } \ + else if(row_ % 2 == 0) \ + { \ + MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_(1, 
sub_token_onshot_, local_expert_masking_); \ + } + +#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \ + if(is_sub_token_onshot) \ + { \ + MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \ + } + +#define MOE_SORTING_DISPATCH_EMASK_(row_) \ + if(is_local_expert_masking) \ + { \ + MOE_SORTING_DISPATCH_SUBTO_(row_, true) \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_SUBTO_(row_, false) \ + } + +#endif + +#if !MOE_SORTING_USE_EX_KERNEL #define MOE_SORTING_DISPATCH(unroll_num_) \ if(a.num_experts <= 8) \ { \ @@ -38,11 +105,13 @@ { \ MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \ } +#endif float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) { if(t.weight_type == "fp32" && t.index_type == "int32") { +#if !MOE_SORTING_USE_EX_KERNEL if(a.num_experts > 127) { printf("lds size exceed, only support experts <127 \n"); @@ -83,6 +152,19 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi MOE_SORTING_DISPATCH(4); } } +#else + using index_t = ck_tile::index_t; + using ms_weight_type = float; + auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts); + auto sub_token_ = r_ - 2; + r_ = (r_ - 2) / 8; + bool is_sub_token_onshot = a.tokens <= sub_token_; + bool is_local_expert_masking = t.local_expert_masking; + (void)c_; + + MOE_SORTING_DISPATCH_EMASK_(r_); + // MOE_SORTING_DISPATCH_ETILE(0, 0); +#endif } return -1; } diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp index 0cb393f7de..5bda4d368a 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp @@ -10,7 +10,8 @@ struct moe_sorting_trait { std::string index_type; - std::string weight_type; // currently always float + std::string weight_type; // currently always float + bool 
local_expert_masking; // if mask experts as local expert }; struct moe_sorting_args : public ck_tile::MoeSortingHostArgs diff --git a/example/ck_tile/13_moe_sorting/script/smoke_test.sh b/example/ck_tile/13_moe_sorting/script/smoke_test.sh index 3ff8a7332d..cf2c2e164b 100644 --- a/example/ck_tile/13_moe_sorting/script/smoke_test.sh +++ b/example/ck_tile/13_moe_sorting/script/smoke_test.sh @@ -17,4 +17,12 @@ $EXE -t=71 -e=11 -k=11 $EXE -t=1 -e=1 -k=1 $EXE -t=99 -e=2 -k=1 $EXE -t=333 -e=99 -k=13 +$EXE -t=11 -e=256 -k=5 +$EXE -t=64 -e=455 -k=8 +$EXE -t=777 -e=802 -k=99 +$EXE -t=4097 -e=906 -k=51 $EXE -t=128 -e=32 -k=5 -moe_buf_size=262144 +$EXE -t=13 -e=64 -k=3 -local_eid=4,5,6,7,8,9,10,11 +$EXE -t=99 -e=33 -k=9 -local_eid=6,10,11,15,19 +$EXE -t=80 -e=99 -k=10 -local_eid=0,8,12,33 +$EXE -t=11 -e=256 -k=5 -local_eid=99,110,129 diff --git a/example/ck_tile/15_fused_moe/README.md b/example/ck_tile/15_fused_moe/README.md index b6ceabf351..089e1de78e 100644 --- a/example/ck_tile/15_fused_moe/README.md +++ b/example/ck_tile/15_fused_moe/README.md @@ -42,7 +42,7 @@ summary of the key design of this fused-moe operator: // (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5 // weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]] // -// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1) +// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated) // * this could be larger than actual, since actual tokens are on GPU // // sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5] diff --git a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp index 7ca24c5c9a..805cd54878 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp @@ -3,6 +3,12 @@ #include "fused_moesorting.hpp" +#ifndef 
MOE_SORTING_USE_EX_KERNEL +#define MOE_SORTING_USE_EX_KERNEL 1 +#endif + +#if !MOE_SORTING_USE_EX_KERNEL + #define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \ constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr ck_tile::index_t expert_tile = expert_tile_; \ @@ -17,6 +23,24 @@ s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ return ave_time; +#else +#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_) \ + constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \ + constexpr bool sub_token_onshot = sub_token_onshot_; \ + using ms_problem = \ + ck_tile::MoeSortingProblemEx; \ + using kernel = ck_tile::MoeSortingKernel; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + const auto lds_bytes = kernel::GetSmemSize(a); \ + float ave_time = ck_tile::launch_kernel( \ + s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ + return ave_time; + +#endif + +#if !MOE_SORTING_USE_EX_KERNEL #define MOE_SORTING_DISPATCH(unroll_num_) \ if(a.num_experts <= 8) \ { \ @@ -38,11 +62,13 @@ { \ MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \ } +#endif float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s) { if(t.weight_type == "fp32" && t.index_type == "int32") { +#if !MOE_SORTING_USE_EX_KERNEL if(a.num_experts > 127) { printf("lds size exceed, only support experts <127 \n"); @@ -83,6 +109,54 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til MOE_SORTING_DISPATCH(4); } } +#else + using index_t = ck_tile::index_t; + using ms_weight_type = float; + auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts); + auto sub_token_ = r_ - 2; + r_ = (r_ - 2) / 8; + bool is_sub_token_onshot = a.tokens <= sub_token_; + (void)c_; + if(is_sub_token_onshot) + { + if(r_ % 8 == 0) + { + MOE_SORTING_DISPATCH_(8, true); + } + else if(r_ % 4 == 0) + { + 
MOE_SORTING_DISPATCH_(4, true); + } + else if(r_ % 2 == 0) + { + MOE_SORTING_DISPATCH_(2, true); + } + else + { + MOE_SORTING_DISPATCH_(1, true); + } + } + else + { + if(r_ % 8 == 0) + { + MOE_SORTING_DISPATCH_(8, false); + } + else if(r_ % 4 == 0) + { + MOE_SORTING_DISPATCH_(4, false); + } + else if(r_ % 2 == 0) + { + MOE_SORTING_DISPATCH_(2, false); + } + else + { + MOE_SORTING_DISPATCH_(1, false); + } + } + // MOE_SORTING_DISPATCH_ETILE(0, 0); +#endif } return -1; } diff --git a/include/ck_tile/host/reference/reference_moe_sorting.hpp b/include/ck_tile/host/reference/reference_moe_sorting.hpp index 3851629cc2..47f0ba576b 100644 --- a/include/ck_tile/host/reference/reference_moe_sorting.hpp +++ b/include/ck_tile/host/reference/reference_moe_sorting.hpp @@ -14,12 +14,15 @@ namespace ck_tile { template CK_TILE_HOST void reference_moe_sorting(const HostTensor& topk_ids, const HostTensor& weights, + const HostTensor& local_expert_mask, HostTensor& p_sorted_token_ids, HostTensor& sorted_weight, HostTensor& sorted_expert_ids, index_t& unit_cnt, const index_t experts, - const index_t unit_size) + const index_t unit_size, + bool local_expert_masking, + bool skip_experts_with_zero_token = true) { const index_t num_token = topk_ids.mDesc.get_lengths()[0]; const index_t topk = topk_ids.mDesc.get_lengths()[1]; @@ -33,8 +36,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor& topk_ids, #endif std::vector> expert_token_weights( experts, std::vector(unit_size, 0)); + // count number of unit-size slices in this expert std::vector expert_slices(experts, 1); + // count the tokens used in this expert std::vector expert_slice_idxs(experts, 0); + // TODO: above 2 buffer seems duplicated for(index_t t = 0; t < num_token; t++) { @@ -72,8 +78,23 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor& topk_ids, IndexType* out_tokens = p_sorted_token_ids.data(); WeightType* out_weights = sorted_weight.data(); IndexType* out_expert_id = sorted_expert_ids.data(); + int 
curr_expert_id = 0; for(index_t e = 0; e < experts; e++) { + if(local_expert_masking) + { + if(local_expert_mask(e) == 0) + continue; + } + if(skip_experts_with_zero_token) + { + if(expert_slice_idxs[e] == 0) + { + curr_expert_id++; + continue; + } + } + memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size); out_tokens += expert_slices[e] * unit_size; memcpy(out_weights, @@ -83,10 +104,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor& topk_ids, for(index_t s = 0; s < expert_slices[e]; s++) { - out_expert_id[s] = e; + out_expert_id[s] = curr_expert_id; unit_cnt++; } out_expert_id += expert_slices[e]; + curr_expert_id++; } unit_cnt *= unit_size; return; diff --git a/include/ck_tile/ops/fused_moe.hpp b/include/ck_tile/ops/fused_moe.hpp index 3ffb0a9ca2..ddb64a2189 100644 --- a/include/ck_tile/ops/fused_moe.hpp +++ b/include/ck_tile/ops/fused_moe.hpp @@ -7,6 +7,7 @@ #include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp" #include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp" #include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp" +#include "ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp" @@ -14,7 +15,6 @@ #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp" -#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp 
b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp index a7eeb3c0e3..efa1ccb311 100644 --- a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp @@ -22,7 +22,7 @@ // (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5 // weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]] // -// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1) +// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated) // * this could be larger than actual, since actual tokens are on GPU // // sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5] diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 30e68996b6..340f6cb9e5 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -15,6 +15,10 @@ namespace ck_tile { #define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \ static_cast(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24)) +#ifndef MOE_SORTING_USE_EX_KERNEL +#define MOE_SORTING_USE_EX_KERNEL 1 +#endif + // clang-format off // [indexing implementation-1] // using M_a as constexpr block_size to partition all tokens into different slices @@ -28,7 +32,7 @@ namespace ck_tile { // (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5 // weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]] // -// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1) +// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated) // * this could be larger than actual, since actual tokens are on GPU // // sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5] @@ -55,6 +59,34 @@ namespace ck_tile { // 
num_tokens_post_padded_ptr : [28] // num_sorted_tiles_ptr : [7] // +// skip_experts_with_zero_tokens(SkipExpertsWithZeroTokens) +// if enabled, the expert with no tokens will be skipped, instead of padding to at least 1 unit_size(M_a) +// +// (pack below tensor, skip element marked with `-`) +// Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y - - - - Y Y Y Y +// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5] +// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -| +// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o] +// +// +// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 5] +// num_tokens_post_padded_ptr : [24] +// +// * local_expert_mask : indicate local expert mask used on current GPU (used for EP case) +// and modify the output expert-ID, because we will only have enabled expert on specific GPU. +// we call expert input to this kernel as "global expert id", output as "local expert id" +// +// * local_expert_mask : [1, 0, 1, 1, 0, 1] (mask out expert-id=1, 4) +// +// (pack below tensor, skip element marked with `-`) +// Y Y Y Y - - - - Y Y Y Y Y Y Y Y Y Y Y Y - - - - Y Y Y Y +// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5] +// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -| +// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o] +// +// sorted_expert_ids_ptr : [0, 1, 2, 2, 3] (note originally it was expert-id= 0, 2, 3, 5, but we produce "local expert id") +// num_tokens_post_padded_ptr : [20] +// // * different from vLLM // 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id // 2)need sorted_weight_ptr @@ -67,10 +99,80 @@ namespace ck_tile { // 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one) // // max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1) +
+ +CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int num_tokens_, int num_experts_) +{ + /* num_experts + 1 + * +--------------------------------------+ + * | | + * | | + * | | * -> sub-tokens + * | | + * | | + * +--------------------------------------+ + * | | 2 -> cumsum buffer + * +--------------------------------------+ + * + */ + int smem_cols = num_experts_ + 1; // usually experts is power of 2. padding here + int smem_rows = [&](){ + index_t target_occupancy_ = 2; + constexpr index_t total_ = 65536 / sizeof(int); + constexpr index_t sub_unroll = 8; + constexpr index_t cumsum_bufs = 2; // 1 for cumsum, 1 for cnt + // at lease 2 lines, one for sub_token unroll, one for cumsum + // should be enough + if ((total_ / target_occupancy_) < ((cumsum_bufs+sub_unroll) * smem_cols)) { + if ((total_ / 1) < ((cumsum_bufs+sub_unroll) * smem_cols)) + throw std::runtime_error("too many num_experts, can't allocate smem"); + target_occupancy_ = 1; + } + int r = total_ / target_occupancy_ / smem_cols; + + // round to sub_unroll multipl + int r_for_sub_token = r - cumsum_bufs; + r_for_sub_token = min(r_for_sub_token, num_tokens_); + r_for_sub_token = (r_for_sub_token + sub_unroll - 1) / sub_unroll * sub_unroll; + r_for_sub_token = max(r_for_sub_token, 1); + + if(r_for_sub_token > 1) + { + int r_unroll_ = r_for_sub_token / sub_unroll; + + + // round to 1x/2x/4x/8x number of sub_unroll + int clz_ = __builtin_clz(r_unroll_); // 0b1:31 0b2:30, 0b3:30, 0b4:29 + int mask_ = (1 << (31 - clz_)) - 1; + + + mask_ = mask_ > 0b111 ? 
0b111 : mask_; //clamp to 8x at most + mask_ = ~mask_; + //printf("r_unroll_:%d, clz:%d, mask:%x\n", r_unroll_, clz_, mask_); fflush(stdout); + + r_for_sub_token = (r_unroll_ & mask_) * sub_unroll; + } + + // final check + if( (r_for_sub_token + cumsum_bufs * smem_cols * target_occupancy_ ) >= total_ ) { + throw std::runtime_error("can't run this kernel, request LDS over size"); + } + + return r_for_sub_token + cumsum_bufs; + }(); + + // printf("r:%d, c:%d\n", smem_rows, smem_cols); + + return ck_tile::make_tuple(smem_rows, smem_cols); +} + struct MoeSortingHostArgs { const void* p_topk_ids; // [token, topk] const void* p_weights; // [token, topk] + + const void* p_local_expert_mask; + void* p_sorted_token_ids; void* p_sorted_weights; void* p_sorted_expert_ids; @@ -101,6 +203,7 @@ struct MoeSortingKernel { const void* p_topk_ids; const void* p_weights; + const void* p_local_expert_mask; void* p_sorted_token_ids; void* p_sorted_weights; void* p_sorted_expert_ids; @@ -111,8 +214,11 @@ struct MoeSortingKernel index_t moe_buf_bytes; index_t tokens_per_thread; + index_t smem_rows; mdiv unit_size_mdiv; mdiv topk_mdiv; + mdiv expert_mdiv; + // mdiv sub_tokens_mdiv; }; CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) @@ -123,15 +229,25 @@ struct MoeSortingKernel CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h) { +#if MOE_SORTING_USE_EX_KERNEL + (void)h; + return dim3(256); +#else return dim3(ck_tile::integer_least_multiple(h.num_experts, ck_tile::get_warp_size())); +#endif } // in byte CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h) { +#if MOE_SORTING_USE_EX_KERNEL + auto [smem_rows, smem_cols] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts); + return smem_rows * smem_cols * sizeof(int); +#else const auto blocks = BlockSize(h); // usually num_experts is power of 2, we pad 1 dword here for the row-size return ((blocks.x + 1) * (h.num_experts + 1) + (h.num_experts + 1)) * sizeof(index_t); +#endif } CK_TILE_HOST static constexpr 
auto MakeKargs(const Hargs& h) @@ -139,6 +255,7 @@ struct MoeSortingKernel Kargs k; k.p_topk_ids = h.p_topk_ids; k.p_weights = h.p_weights; + k.p_local_expert_mask = h.p_local_expert_mask; k.p_sorted_token_ids = h.p_sorted_token_ids; k.p_sorted_weights = h.p_sorted_weights; k.p_sorted_expert_ids = h.p_sorted_expert_ids; @@ -152,10 +269,18 @@ struct MoeSortingKernel k.tokens_per_thread = integer_divide_ceil(h.tokens * h.topk, blocks.x); k.unit_size_mdiv = mdiv{static_cast(h.unit_size)}; k.topk_mdiv = mdiv{static_cast(h.topk)}; + k.smem_rows = [&](){ + auto [r_, c_] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts); + (void) c_; + return r_; + }(); + k.expert_mdiv = mdiv{static_cast(h.num_experts)}; + // k.sub_tokens_mdiv = mdiv{static_cast(k.smem_rows - 1)}; return k; } - // [a, b, c, d....] -> [a, a+b, a+b+c, a+b+c+d, ....] + // [a, b, c, d....] -> [a, a+b, a+b+c, a+b+c+d, ....] + // NOTE: wave_size need at least be 16!! dpp 16 is one row template __device__ inline void wave_cumsum(data_t& thread_data) const { @@ -196,6 +321,40 @@ struct MoeSortingKernel bank_mask, bound_ctrl))); // row_shr:4 } + if constexpr(wave_size == 8) { + + // wave-size=8 need one extra shift + thread_data = + reduce_op(thread_data, + __builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data), + 0x118, + row_mask, + bank_mask, + bound_ctrl))); // row_shr:8 +#if 0 + constexpr int bank_mask_0_7 = 0b1100; + auto reduce_op_r = [&](auto x_, auto y_) { return x_ - y_; }; + thread_data = reduce_op_r(thread_data, __builtin_bit_cast(data_t, + __builtin_amdgcn_update_dpp(0, /* old value */ + __builtin_bit_cast(int, thread_data), + 0x157, + row_mask, + bank_mask_0_7, + bound_ctrl))// row_newbcast:7 + ); +#else + data_t xxx =__builtin_bit_cast(data_t, + __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data), + 0x157, + row_mask, + bank_mask, + bound_ctrl)); // row_newbcast:7 + + data_t yyy = (__lane_id() / 8) % 2 == 0 ? 
0 : xxx; + thread_data = thread_data - yyy; +#endif + + } if constexpr(wave_size > 8) { thread_data = @@ -224,6 +383,36 @@ struct MoeSortingKernel } } + // reduce single pixel within a wave + template + __device__ static constexpr T wave_reduce(T local, F reduce_f, number = {}) + { + // constexpr int wave_size = 64; + // constexpr int reduce_stage = 6; // 1<<6=64 + // clang-format off + constexpr int reduce_stage = [](){ + if constexpr(wave_size_ == 2) return 1; + else if constexpr(wave_size_ == 4) return 2; + else if constexpr(wave_size_ == 8) return 3; + else if constexpr(wave_size_ == 16) return 4; + else if constexpr(wave_size_ == 32) return 5; + else if constexpr(wave_size_ == 64) return 6; + else return 0; + }(); + // clang-format on + T v_local = local; +#pragma unroll reduce_stage + for(int i_stage = 0; i_stage < reduce_stage; i_stage++) + { + int src_lane = __lane_id() ^ (1 << i_stage); + int32_t v_remote_tmp = + __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast(v_local)); + T v_remote = bit_cast(v_remote_tmp); + v_local = reduce_f(v_local, v_remote); + } + return v_local; + } + CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const { return row * total_col + col; @@ -257,37 +446,37 @@ struct MoeSortingKernel index_t* shared_mem = reinterpret_cast(smem); index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts) - index_t* cumsum = shared_mem + (blockDim.x + 1) * (num_experts+1); // 1: (num_experts + 1) + index_t* cumsum = shared_mem + (blockDim.x + 1) * (num_experts + 1); // 1: (num_experts + 1) for(int i = 0; i < num_experts; ++i) { - tokens_cnts[calc_index(num_experts+1, tid + 1, i)] = 0; + tokens_cnts[calc_index(num_experts + 1, tid + 1, i)] = 0; } #pragma unroll Problem_::InternalLoadUnroll for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - ++tokens_cnts[calc_index(num_experts+1, tid + 1, topk_id[i])]; + ++tokens_cnts[calc_index(num_experts + 1, tid + 1, topk_id[i])]; } 
__syncthreads(); #if 1 if(tid < num_experts) { - tokens_cnts[calc_index(num_experts+1, 0, tid)] = 0; + tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0; index_t local_c[8]; index_t prev_c = 0; // TODO: manually unroll. pragma unroll does not work well when we have dependency - for(int i = 1; i <= static_cast(blockDim.x); i+= 8) + for(int i = 1; i <= static_cast(blockDim.x); i += 8) { - local_c[0] = tokens_cnts[calc_index(num_experts+1, i + 0, tid)]; - local_c[1] = tokens_cnts[calc_index(num_experts+1, i + 1, tid)]; - local_c[2] = tokens_cnts[calc_index(num_experts+1, i + 2, tid)]; - local_c[3] = tokens_cnts[calc_index(num_experts+1, i + 3, tid)]; - local_c[4] = tokens_cnts[calc_index(num_experts+1, i + 4, tid)]; - local_c[5] = tokens_cnts[calc_index(num_experts+1, i + 5, tid)]; - local_c[6] = tokens_cnts[calc_index(num_experts+1, i + 6, tid)]; - local_c[7] = tokens_cnts[calc_index(num_experts+1, i + 7, tid)]; + local_c[0] = tokens_cnts[calc_index(num_experts + 1, i + 0, tid)]; + local_c[1] = tokens_cnts[calc_index(num_experts + 1, i + 1, tid)]; + local_c[2] = tokens_cnts[calc_index(num_experts + 1, i + 2, tid)]; + local_c[3] = tokens_cnts[calc_index(num_experts + 1, i + 3, tid)]; + local_c[4] = tokens_cnts[calc_index(num_experts + 1, i + 4, tid)]; + local_c[5] = tokens_cnts[calc_index(num_experts + 1, i + 5, tid)]; + local_c[6] = tokens_cnts[calc_index(num_experts + 1, i + 6, tid)]; + local_c[7] = tokens_cnts[calc_index(num_experts + 1, i + 7, tid)]; local_c[0] += prev_c; local_c[1] += local_c[0]; @@ -299,51 +488,57 @@ struct MoeSortingKernel local_c[7] += local_c[6]; prev_c = local_c[7]; - tokens_cnts[calc_index(num_experts+1, i + 0, tid)] = local_c[0]; - tokens_cnts[calc_index(num_experts+1, i + 1, tid)] = local_c[1]; - tokens_cnts[calc_index(num_experts+1, i + 2, tid)] = local_c[2]; - tokens_cnts[calc_index(num_experts+1, i + 3, tid)] = local_c[3]; - tokens_cnts[calc_index(num_experts+1, i + 4, tid)] = local_c[4]; - tokens_cnts[calc_index(num_experts+1, i + 
5, tid)] = local_c[5]; - tokens_cnts[calc_index(num_experts+1, i + 6, tid)] = local_c[6]; - tokens_cnts[calc_index(num_experts+1, i + 7, tid)] = local_c[7]; + tokens_cnts[calc_index(num_experts + 1, i + 0, tid)] = local_c[0]; + tokens_cnts[calc_index(num_experts + 1, i + 1, tid)] = local_c[1]; + tokens_cnts[calc_index(num_experts + 1, i + 2, tid)] = local_c[2]; + tokens_cnts[calc_index(num_experts + 1, i + 3, tid)] = local_c[3]; + tokens_cnts[calc_index(num_experts + 1, i + 4, tid)] = local_c[4]; + tokens_cnts[calc_index(num_experts + 1, i + 5, tid)] = local_c[5]; + tokens_cnts[calc_index(num_experts + 1, i + 6, tid)] = local_c[6]; + tokens_cnts[calc_index(num_experts + 1, i + 7, tid)] = local_c[7]; } } #else - // TODO: below code still working, but slow in expert=32/topk=5 case. Put here for future heuristic + // TODO: below code still working, but slow in expert=32/topk=5 case. Put here for future + // heuristic { if(tid < num_experts) - tokens_cnts[calc_index(num_experts+1, 0, tid)] = 0; - for(int i = 0; i < num_experts; i+=8) { + tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0; + for(int i = 0; i < num_experts; i += 8) + { index_t local_c[8]; - #pragma unroll - for(int j = 0; j < 8; j++) { - local_c[j] = tokens_cnts[calc_index(num_experts+1, tid+1, i+j)]; +#pragma unroll + for(int j = 0; j < 8; j++) + { + local_c[j] = tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)]; } - #pragma unroll - for(int j = 0; j < 8; j++) { +#pragma unroll + for(int j = 0; j < 8; j++) + { wave_cumsum(local_c[j]); } - #pragma unroll - for(int j = 0; j < 8; j++) { - tokens_cnts[calc_index(num_experts+1, tid+1, i+j)] = local_c[j]; +#pragma unroll + for(int j = 0; j < 8; j++) + { + tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)] = local_c[j]; } } } #endif __syncthreads(); - if constexpr (Problem::ExpertTile == 0) { + if constexpr(Problem::ExpertTile == 0) + { if(tid == 0) { cumsum[0] = 0; for(int i = 1; i <= num_experts; ++i) { auto current_units = [&]() { - index_t 
x_ = tokens_cnts[calc_index(num_experts+1, blockDim.x, i - 1)] + - unit_size_mdiv.divisor - 1; + index_t x_ = tokens_cnts[calc_index(num_experts + 1, blockDim.x, i - 1)] + + unit_size_mdiv.divisor - 1; index_t y_ = unit_size_mdiv.div(x_); return max(y_, 1) * unit_size_mdiv.divisor; }(); @@ -351,20 +546,24 @@ struct MoeSortingKernel } *p_total_tokens_post_pad = cumsum[num_experts]; } - } else { - // TODO: we have out-of-bound read here. But result is still OK (will ignore tid >= expert) - // for simplicity, not check experts here. - int local_cnt = tokens_cnts[calc_index(num_experts+1, blockDim.x, tid)]; + } + else + { + // TODO: we have out-of-bound read here. But result is still OK (will ignore tid >= + // expert) for simplicity, not check experts here. + int local_cnt = tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)]; int blocks_pers_expert = unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1); int padded_tokens_per_expert = max(blocks_pers_expert, 1) * unit_size_mdiv.divisor; - int local_cumsum = padded_tokens_per_expert; + int local_cumsum = padded_tokens_per_expert; wave_cumsum(local_cumsum); - if(tid == (num_experts - 1)) { - cumsum[0] = 0; + if(tid == (num_experts - 1)) + { + cumsum[0] = 0; *p_total_tokens_post_pad = local_cumsum; } - if(tid < num_experts) { + if(tid < num_experts) + { cumsum[tid + 1] = local_cumsum; } } @@ -373,7 +572,7 @@ struct MoeSortingKernel if(tid < num_experts) { int e_start = cumsum[tid]; - int e_end = cumsum[tid + 1]; + int e_end = cumsum[tid + 1]; for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor) { p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid; @@ -383,8 +582,8 @@ struct MoeSortingKernel #pragma unroll Problem_::InternalLoadUnroll for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - index_t expert_id = topk_id[i]; - index_t local_cnt = tokens_cnts[calc_index(num_experts+1, tid, expert_id)]; + index_t expert_id = topk_id[i]; + index_t local_cnt = 
tokens_cnts[calc_index(num_experts + 1, tid, expert_id)]; index_t rank_post_pad = local_cnt + cumsum[expert_id]; #if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID uint32_t curr_token_id, curr_topk_id; @@ -393,16 +592,17 @@ struct MoeSortingKernel #else p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i); #endif - p_sorted_weights[rank_post_pad] = weights[i]; - tokens_cnts[calc_index(num_experts+1, tid, expert_id)] = local_cnt+1; + p_sorted_weights[rank_post_pad] = weights[i]; + tokens_cnts[calc_index(num_experts + 1, tid, expert_id)] = local_cnt + 1; } - if constexpr (Problem::ExpertTile == 0) { + if constexpr(Problem::ExpertTile == 0) + { const index_t prefill_token = topk_mdiv.div(numel); if(tid < num_experts) { index_t expert_offset = - cumsum[tid] + tokens_cnts[calc_index(num_experts+1, blockDim.x, tid)]; + cumsum[tid] + tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)]; index_t expert_end = cumsum[tid + 1]; while(expert_offset < expert_end) { @@ -417,16 +617,19 @@ struct MoeSortingKernel } } } - else { + else + { const index_t prefill_token = topk_mdiv.div(numel); // TODO: only support expert-tile like 8, 16, 32 static constexpr index_t experts_per_wave = warpSize / Problem::ExpertTile; { - index_t eid = tid / experts_per_wave; - index_t expert_offset = - cumsum[eid] + tokens_cnts[calc_index(num_experts+1, blockDim.x, eid)] + tid % experts_per_wave; + index_t eid = tid / experts_per_wave; + index_t expert_offset = cumsum[eid] + + tokens_cnts[calc_index(num_experts + 1, blockDim.x, eid)] + + tid % experts_per_wave; index_t expert_end = cumsum[eid + 1]; - if(eid < num_experts) { + if(eid < num_experts) + { while(expert_offset < expert_end) { #if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID @@ -436,10 +639,363 @@ struct MoeSortingKernel p_sorted_token_ids[expert_offset] = prefill_token; #endif p_sorted_weights[expert_offset] = static_cast(0.0); - expert_offset+=experts_per_wave; + expert_offset += experts_per_wave; } } - } + } + } + } + + // only support index_t, and 
single pixel access + struct simple_smem_indexer + { + index_t* smem; + index_t row_stride; + + // this is 2D + CK_TILE_DEVICE simple_smem_indexer(index_t* smem_, index_t row_stride_) + : smem(smem_), row_stride(row_stride_) + { + } + CK_TILE_DEVICE const index_t& operator()(index_t i_row, index_t i_col) const + { + return smem[i_row * row_stride + i_col]; + } + CK_TILE_DEVICE index_t& operator()(index_t i_row, index_t i_col) + { + return smem[i_row * row_stride + i_col]; + } + + // this is 1D or linear + CK_TILE_DEVICE simple_smem_indexer(index_t* smem_) : smem(smem_), row_stride(0) {} + CK_TILE_DEVICE const index_t& operator()(index_t idx) const { return smem[idx]; } + CK_TILE_DEVICE index_t& operator()(index_t idx) { return smem[idx]; } + }; + + CK_TILE_DEVICE void + moe_align_block_size_kernel_ex(const IndexType* __restrict__ topk_id, + const WeightType* __restrict__ weights, + const IndexType* __restrict__ local_expert_mask, + index_t* p_sorted_token_ids, + WeightType* p_sorted_weights, + index_t* p_sorted_expert_ids, + index_t* p_total_tokens_post_pad, + const index_t num_experts, + const index_t tokens, + const mdiv unit_size_mdiv, + const mdiv topk_mdiv, + const mdiv expert_mdiv, + const index_t smem_rows, + void* smem) const + { + const index_t tid = static_cast(threadIdx.x); + const index_t wid = __builtin_amdgcn_readfirstlane(tid / warpSize); + const index_t lid = __lane_id(); + constexpr index_t block_size = 256; // blockDim.x; + const index_t sub_tokens = smem_rows - 2; // sub_tokens_mdiv.divisor; + const index_t topk = topk_mdiv.divisor; + auto f_sum = [](auto x_, auto y_) { return x_ + y_; }; + + const index_t smem_cols = num_experts + 1; + + simple_smem_indexer smem_cumsum{reinterpret_cast(smem) + 0}; + simple_smem_indexer smem_cumdup{reinterpret_cast(smem) + smem_cols}; + simple_smem_indexer smem_tokens{reinterpret_cast(smem) + 2 * smem_cols, + smem_cols}; + + // #pragma unroll 8 + for(int i = tid; i < (sub_tokens * num_experts); i += block_size) + 
{ + uint32_t curr_token_id, curr_expert_id; + expert_mdiv.divmod(i, curr_token_id, curr_expert_id); + smem_tokens(curr_token_id, curr_expert_id) = 0; + } + __syncthreads(); + + for(int i_token = 0; i_token < tokens; i_token += sub_tokens) + { + // NOTE: below for loop can't have barrier inside!! + for(int i = tid; i < (sub_tokens * topk); i += block_size) + { + uint32_t curr_token_id, curr_topk_id; + topk_mdiv.divmod(i, curr_token_id, curr_topk_id); + int i_t = i_token + curr_token_id; + + if(i_t < tokens) + { + int eid = topk_id[i_t * topk + curr_topk_id]; + + if constexpr(Problem::SubTokenOneShot) + smem_tokens(curr_token_id, eid) = curr_topk_id + 1; + else + smem_tokens(curr_token_id, eid)++; + } + __builtin_amdgcn_s_waitcnt(0xc07f); + } + __syncthreads(); // make sure different i_token iteration not overlap by different wave + } + + // counting + if(tid == 0) + { + smem_cumsum(0) = 0; + // smem_cumdup(0) = 0; + } + + { + constexpr int lane_group_sz = 8; + int lane_group_id = tid / lane_group_sz; + int lane_group_os = tid % lane_group_sz; + constexpr int lane_group_nm = block_size / lane_group_sz; + + for(int i_e = lane_group_id; i_e < num_experts; i_e += lane_group_nm) + { + index_t local_c[Problem::SubTokenTile]; + index_t cnt = 0; + + for(int i = 0; i < sub_tokens; i += 8 * Problem::SubTokenTile) + { +#pragma unroll Problem::SubTokenTile + for(int j = 0; j < Problem::SubTokenTile; j++) + { + local_c[j] = smem_tokens(i + j * 8 + lane_group_os, i_e); + if constexpr(Problem::SubTokenOneShot) + { + local_c[j] = local_c[j] != 0 ? 
1 : 0; + } + } + +#pragma unroll Problem::SubTokenTile + for(int j = 0; j < Problem::SubTokenTile; j++) + { + cnt += wave_reduce(local_c[j], f_sum, number<8>{}); + } + } + if(lane_group_os == 0) + smem_cumsum(i_e + 1) = cnt; + } + } + + if constexpr(Problem::LocalExpertMasking) + { + smem_cumdup(0) = 0; + for(int i_e = tid; i_e < num_experts; i_e += block_size) + { + // reuse this buffer + smem_cumdup(i_e + 1) = local_expert_mask[i_e]; + } + } + + __syncthreads(); + + { + if(wid == 0) + { + // NOTE: under this block can never use __syncthreads! + int i_e_ = 0; + int local_cumsum_ = 0; + for(; i_e_ < num_experts; i_e_ += warpSize) + { + int pre_cumsum_ = smem_cumsum(lid == 0 ? i_e_ : 0); + int local_cnt = smem_cumsum(i_e_ + lid + 1); + int blocks_pers_expert = + unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1); + + int pre_cumsum_masking = [&]() { + if constexpr(Problem::LocalExpertMasking) + return smem_cumdup(lid == 0 ? i_e_ : 0); + else + return 0; // not used + }(); + int local_masking = [&]() { + if constexpr(Problem::LocalExpertMasking) + return smem_cumdup(i_e_ + lid + 1); + else + return 0; // not used + }(); + int padded_tokens_per_expert = [&]() { + int x_ = [&]() { + if constexpr(Problem::SkipExpertsWithZeroTokens) + { + // if local_cnt is zero, blocks_pers_expert will be zero + // this is what we want to achieve + return blocks_pers_expert * unit_size_mdiv.divisor; + } + else + { + return max(blocks_pers_expert, 1) * unit_size_mdiv.divisor; + } + }(); + if constexpr(Problem::LocalExpertMasking) + { + return local_masking ? 
x_ : 0; + } + else + return x_; + }(); + + local_cumsum_ = padded_tokens_per_expert; + local_cumsum_ += pre_cumsum_; // note pre_cumsum must be added after local + // cumsum padded in case local cumsum is zero, but + // pre_sumsum has value, which will result int + // zero local cumsum(but we want at least padded) + wave_cumsum(local_cumsum_); + + if((i_e_ + lid) < num_experts) + smem_cumsum(i_e_ + lid + 1) = local_cumsum_; + + if constexpr(Problem::LocalExpertMasking) + { + local_masking += pre_cumsum_masking; + wave_cumsum(local_masking); + if((i_e_ + lid) < num_experts) + smem_cumdup(i_e_ + lid + 1) = local_masking; + } + + // NOTE: this waitcnt is a must, compiler will not generate waitcnt lgkmcnt() + // for above write however __syncthreads will cause barrier with waves other + // than 0(which is not we want) + __builtin_amdgcn_s_waitcnt(0xc07f); + } + if((lid + i_e_ - warpSize) == (num_experts - 1)) + { + *p_total_tokens_post_pad = local_cumsum_; + } + } + __syncthreads(); + } + + for(int i_e = tid; i_e < num_experts; i_e += block_size) + { + int e_start = smem_cumsum(i_e); + int e_end = smem_cumsum(i_e + 1); + + int expert_id = [&]() { + if constexpr(Problem::LocalExpertMasking) + { + // local expert id from cumsum + return smem_cumdup(i_e); + } + else + return i_e; + }(); + + smem_cumdup(i_e) = e_start; // duplicate cumsum for later use + if constexpr(Problem::SkipExpertsWithZeroTokens) + { + if(e_start == e_end) // skip zero token expert + continue; + } + + if constexpr(Problem::LocalExpertMasking) + { + if(local_expert_mask[i_e] == 0) + continue; + } + + for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor) + { + p_sorted_expert_ids[unit_size_mdiv.div(i)] = expert_id; + } + } + smem_cumdup(num_experts) = smem_cumsum(num_experts); + + // fill the p_sorted_token_ids/p_sorted_weights + for(int i_token = 0; i_token < tokens; i_token += sub_tokens) + { + if constexpr(!Problem::SubTokenOneShot) + { + // clear every time + for(int i = tid; i < (sub_tokens 
* num_experts); i += block_size) + { + uint32_t curr_token_id, curr_expert_id; + expert_mdiv.divmod(i, curr_token_id, curr_expert_id); + smem_tokens(curr_token_id, curr_expert_id) = 0; + } + __syncthreads(); + + // load again + for(int i = tid; i < (sub_tokens * topk); i += block_size) + { + uint32_t curr_token_id_, curr_topk_id_; + topk_mdiv.divmod(i, curr_token_id_, curr_topk_id_); + int curr_token_id = static_cast(curr_token_id_); + int curr_topk_id = static_cast(curr_topk_id_); + int i_t = i_token + curr_token_id; + if(i_t < tokens) + { + int eid = topk_id[i_t * topk + curr_topk_id]; + smem_tokens(curr_token_id, eid) = curr_topk_id + 1; // at least 1 + } + } + __syncthreads(); + } + + { + constexpr int lane_group_sz = 8; + int lane_group_id = tid / lane_group_sz; + int lane_group_os = tid % lane_group_sz; + constexpr int lane_group_nm = block_size / lane_group_sz; + for(int eid = lane_group_id; eid < num_experts; eid += lane_group_nm) + { + if constexpr(Problem::LocalExpertMasking) + { + if(local_expert_mask[eid] == 0) + continue; + } + int position = smem_cumsum(eid); + for(int i_sub_token = lane_group_os; i_sub_token < sub_tokens; + i_sub_token += lane_group_sz) + { + auto x = smem_tokens(i_sub_token, eid); + + int local_cnt_cache = x != 0 ? 
1 : 0; + int local_cnt = local_cnt_cache; + wave_cumsum(local_cnt); + if(x != 0) + { + // now x is topk value +#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID + p_sorted_token_ids[position + local_cnt - 1] = + MOE_SORTING_MOCK_ID(i_token + i_sub_token, x - 1); +#else + p_sorted_token_ids[position + local_cnt - 1] = i_token + i_sub_token; +#endif + p_sorted_weights[position + local_cnt - 1] = + weights[(i_token + i_sub_token) * topk + x - 1]; + } + + int remote_cnt = __builtin_amdgcn_ds_bpermute( + (lane_group_sz * (lane_group_id + 1) - 1) << 2, local_cnt); + + position += remote_cnt; + } + smem_cumsum(eid) = position; + } + } + __syncthreads(); + } + + // add the skip number + for(int eid = tid; eid < num_experts; eid += block_size) + { + int e_start = smem_cumsum(eid); + int e_end = smem_cumdup(eid + 1); + if constexpr(Problem::SkipExpertsWithZeroTokens) + { + if(e_start == e_end) // skip zero token expert + continue; + } + while(e_start < e_end) + { +#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID + p_sorted_token_ids[e_start] = MOE_SORTING_MOCK_ID(tokens, topk); +#else + p_sorted_token_ids[e_start] = tokens; +#endif + p_sorted_weights[e_start] = static_cast(0.0); + e_start++; + } } } @@ -456,6 +1012,24 @@ struct MoeSortingKernel } const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor; extern __shared__ char smem[]; +#if MOE_SORTING_USE_EX_KERNEL + (void)numel; + return moe_align_block_size_kernel_ex( + static_cast(kargs.p_topk_ids), + static_cast(kargs.p_weights), + static_cast(kargs.p_local_expert_mask), + static_cast(kargs.p_sorted_token_ids), + static_cast(kargs.p_sorted_weights), + static_cast(kargs.p_sorted_expert_ids), + static_cast(kargs.p_total_tokens_post_pad), + kargs.num_experts, + kargs.tokens, + kargs.unit_size_mdiv, + kargs.topk_mdiv, + kargs.expert_mdiv, + kargs.smem_rows, + smem); +#else return moe_align_block_size_kernel(static_cast(kargs.p_topk_ids), static_cast(kargs.p_weights), static_cast(kargs.p_sorted_token_ids), @@ -468,6 +1042,7 @@ struct 
MoeSortingKernel kargs.unit_size_mdiv, kargs.topk_mdiv, smem); +#endif } }; diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp new file mode 100644 index 0000000000..15effe7118 --- /dev/null +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include +#include + +namespace ck_tile { + +template +struct MoeSortingProblem +{ + // TODO: this kernel only support warp per row + using WeightType = remove_cvref_t; + using IndexType = remove_cvref_t; + + static constexpr index_t WarpSize = get_warp_size(); + static constexpr index_t WarpsPerBlock = 1; + static constexpr index_t InternalLoadUnroll = + InternalLoadUnroll_; // TODO: need better design(like tile size) + static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out +}; + +template +struct MoeSortingProblemEx +{ + // TODO: this kernel only support warp per row + using WeightType = remove_cvref_t; + using IndexType = remove_cvref_t; + + static constexpr index_t WarpSize = get_warp_size(); + static constexpr index_t WarpsPerBlock = 1; + static constexpr index_t SubTokenTile = SubTokenTile_; + static constexpr bool SubTokenOneShot = SubTokenOneShot_; + static constexpr bool LocalExpertMasking = LocalExpertMasking_; + static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_; + static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 || SubTokenTile == 8); + static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp b/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp deleted file mode 100644 index 50005c4402..0000000000 --- 
a/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp +++ /dev/null @@ -1,28 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" -#include -#include - -namespace ck_tile { - -template -struct MoeSortingProblem -{ - // TODO: this kernel only support warp per row - using WeightType = remove_cvref_t; - using IndexType = remove_cvref_t; - - static constexpr index_t WarpSize = get_warp_size(); - static constexpr index_t WarpsPerBlock = 1; - static constexpr index_t InternalLoadUnroll = - InternalLoadUnroll_; // TODO: need better design(like tile size) - static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out -}; -} // namespace ck_tile