diff --git a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp index 7ca24c5c9a..805cd54878 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp @@ -3,6 +3,12 @@ #include "fused_moesorting.hpp" +#ifndef MOE_SORTING_USE_EX_KERNEL +#define MOE_SORTING_USE_EX_KERNEL 1 +#endif + +#if !MOE_SORTING_USE_EX_KERNEL + #define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \ constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr ck_tile::index_t expert_tile = expert_tile_; \ @@ -17,6 +23,24 @@ s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ return ave_time; +#else +#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_) \ + constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \ + constexpr bool sub_token_onshot = sub_token_onshot_; \ + using ms_problem = \ + ck_tile::MoeSortingProblemEx; \ + using kernel = ck_tile::MoeSortingKernel; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + const auto lds_bytes = kernel::GetSmemSize(a); \ + float ave_time = ck_tile::launch_kernel( \ + s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ + return ave_time; + +#endif + +#if !MOE_SORTING_USE_EX_KERNEL #define MOE_SORTING_DISPATCH(unroll_num_) \ if(a.num_experts <= 8) \ { \ @@ -38,11 +62,13 @@ { \ MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \ } +#endif float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s) { if(t.weight_type == "fp32" && t.index_type == "int32") { +#if !MOE_SORTING_USE_EX_KERNEL if(a.num_experts > 127) { printf("lds size exceed, only support experts <127 \n"); @@ -83,6 +109,54 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til MOE_SORTING_DISPATCH(4); } } +#else + using index_t = ck_tile::index_t; + using ms_weight_type = float; + auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts); + auto sub_token_ = r_ - 2; + r_ = (r_ - 2) / 8; + bool is_sub_token_onshot = a.tokens <= sub_token_; + (void)c_; + if(is_sub_token_onshot) + { + if(r_ % 8 == 0) + { + MOE_SORTING_DISPATCH_(8, true); + } + else if(r_ % 4 == 0) + { + MOE_SORTING_DISPATCH_(4, true); + } + else if(r_ % 2 == 0) + { + MOE_SORTING_DISPATCH_(2, true); + } + else + { + MOE_SORTING_DISPATCH_(1, true); + } + } + else + { + if(r_ % 8 == 0) + { + MOE_SORTING_DISPATCH_(8, false); + } + else if(r_ % 4 == 0) + { + MOE_SORTING_DISPATCH_(4, false); + } + else if(r_ % 2 == 0) + { + MOE_SORTING_DISPATCH_(2, false); + } + else + { + MOE_SORTING_DISPATCH_(1, false); + } + } + // MOE_SORTING_DISPATCH_ETILE(0, 0); +#endif } return -1; } diff --git a/include/ck_tile/host/reference/reference_moe_sorting.hpp b/include/ck_tile/host/reference/reference_moe_sorting.hpp index 3851629cc2..b144186638 100644 --- a/include/ck_tile/host/reference/reference_moe_sorting.hpp +++ b/include/ck_tile/host/reference/reference_moe_sorting.hpp @@ -19,7 +19,8 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor& topk_ids, HostTensor& sorted_expert_ids, index_t& unit_cnt, const index_t experts, - const index_t unit_size) + const index_t unit_size, + bool skip_experts_with_zero_token = true) { const index_t num_token = topk_ids.mDesc.get_lengths()[0]; const index_t topk = topk_ids.mDesc.get_lengths()[1]; @@ -33,8 +34,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor& topk_ids, #endif std::vector> expert_token_weights( experts, std::vector(unit_size, 0)); + // count number of unit-size slices in this expert std::vector expert_slices(experts, 1); + // count the tokens used in this expert std::vector expert_slice_idxs(experts, 0); + // TODO: above 2 buffer seems duplicated for(index_t t = 0; t < num_token; t++) { @@ -74,6 +78,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor& topk_ids, IndexType* out_expert_id = sorted_expert_ids.data(); for(index_t e = 0; e < experts; e++) { + if(skip_experts_with_zero_token) + { + if(expert_slice_idxs[e] == 0) + continue; + } memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size); out_tokens += expert_slices[e] * unit_size; memcpy(out_weights, diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 58b3f41f55..9e26ed2d84 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -692,8 +692,6 @@ struct MoeSortingKernel for(int i_token = 0; i_token < tokens; i_token += sub_tokens) { -#if 1 - // __syncthreads(); // NOTE: below for loop can't have barrier inside!! for(int i = tid; i < (sub_tokens * topk); i += block_size) { @@ -716,7 +714,6 @@ struct MoeSortingKernel smem_tokens(curr_token_id, eid)++; } __builtin_amdgcn_s_waitcnt(0xc07f); - // } __syncthreads(); // make sure different i_token iteration not overlap by different wave // if(tid == 0) { @@ -740,30 +737,6 @@ struct MoeSortingKernel // e0+e1+e2+e3+e4+e5+e6+e7 // ); // } - -#else - int i = tid; - while(true) - { - __syncthreads(); - if(i >= (sub_tokens * topk)) - break; - uint32_t curr_token_id, curr_topk_id; - topk_mdiv.divmod(i, curr_token_id, curr_topk_id); - int i_t = i_token + curr_token_id; - // printf("[%d] tid:%d, (i_token:%d, curr_token_id:%d)i_t:%d, curr_topk_id:%d, - // tokens:%d\n", - // i, tid, i_token, curr_token_id, i_t, curr_topk_id, tokens); - if(i_t < tokens) - { - int eid = topk_id[i_t * topk + curr_topk_id]; - smem_tokens(curr_token_id, eid)++; - } - - i += block_size; - } - __syncthreads(); -#endif } // counting @@ -919,8 +892,18 @@ struct MoeSortingKernel int local_cnt = smem_cumsum(i_e_ + lid + 1); int blocks_pers_expert = unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1); - int padded_tokens_per_expert = - max(blocks_pers_expert, 1) * unit_size_mdiv.divisor; + int padded_tokens_per_expert = [&]() { + if constexpr(Problem::SkipExpertsWithZeroTokens) + { + // if local_cnt is zero, blocks_pers_expert will be zero + return blocks_pers_expert * unit_size_mdiv.divisor; + } + else + { + return max(blocks_pers_expert, 1) * unit_size_mdiv.divisor; + } + }(); + local_cumsum_ = padded_tokens_per_expert; local_cumsum_ += pre_cumsum_; // note pre_cumsum must be added after local // cumsum padded in case local cumsum is zero, but @@ -952,6 +935,12 @@ struct MoeSortingKernel int e_end = smem_cumsum(i_e + 1); // printf("i_e:%d, e_start:%d, e_end:%d\n", i_e, e_start, e_end); smem_cumdup(i_e) = e_start; // duplicate cumsum for later use + if constexpr(Problem::SkipExpertsWithZeroTokens) + { + if(e_start == e_end) // skip zero token expert + continue; + } + for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor) { p_sorted_expert_ids[unit_size_mdiv.div(i)] = i_e; @@ -1069,6 +1058,11 @@ struct MoeSortingKernel int e_start = smem_cumsum(eid); int e_end = smem_cumdup(eid + 1); // printf("--- eid:%d, e_start:%d, e_end:%d\n", eid, e_start, e_end); + if constexpr(Problem::SkipExpertsWithZeroTokens) + { + if(e_start == e_end) // skip zero token expert + continue; + } while(e_start < e_end) { #if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp index f35bd07bd3..5e1735a8ac 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp @@ -30,17 +30,19 @@ template + bool SkipExpertsWithZeroTokens_ = true, + index_t ExpertTile_ = 0> struct MoeSortingProblemEx { // TODO: this kernel only support warp per row using WeightType = remove_cvref_t; using IndexType = remove_cvref_t; - static constexpr index_t WarpSize = get_warp_size(); - static constexpr index_t WarpsPerBlock = 1; - static constexpr index_t SubTokenTile = SubTokenTile_; - static constexpr bool SubTokenOneShot = SubTokenOneShot_; + static constexpr index_t WarpSize = get_warp_size(); + static constexpr index_t WarpsPerBlock = 1; + static constexpr index_t SubTokenTile = SubTokenTile_; + static constexpr bool SubTokenOneShot = SubTokenOneShot_; + static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_; static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 || SubTokenTile == 8); static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out };