support skip empty tokens for expert sorting

This commit is contained in:
carlushuang
2025-01-29 17:31:08 +08:00
parent 3eabee6b36
commit 001eb68d44
4 changed files with 114 additions and 35 deletions

View File

@@ -3,6 +3,12 @@
#include "fused_moesorting.hpp"
#ifndef MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_USE_EX_KERNEL 1
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr ck_tile::index_t expert_tile = expert_tile_; \
@@ -17,6 +23,24 @@
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#else
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_) \
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
constexpr bool sub_token_onshot = sub_token_onshot_; \
using ms_problem = \
ck_tile::MoeSortingProblemEx<index_t, ms_weight_type, sub_token_tile, sub_token_onshot>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH(unroll_num_) \
if(a.num_experts <= 8) \
{ \
@@ -38,11 +62,13 @@
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
}
#endif
float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s)
{
if(t.weight_type == "fp32" && t.index_type == "int32")
{
#if !MOE_SORTING_USE_EX_KERNEL
if(a.num_experts > 127)
{
printf("lds size exceed, only support experts <127 \n");
@@ -83,6 +109,54 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
MOE_SORTING_DISPATCH(4);
}
}
#else
using index_t = ck_tile::index_t;
using ms_weight_type = float;
auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts);
auto sub_token_ = r_ - 2;
r_ = (r_ - 2) / 8;
bool is_sub_token_onshot = a.tokens <= sub_token_;
(void)c_;
if(is_sub_token_onshot)
{
if(r_ % 8 == 0)
{
MOE_SORTING_DISPATCH_(8, true);
}
else if(r_ % 4 == 0)
{
MOE_SORTING_DISPATCH_(4, true);
}
else if(r_ % 2 == 0)
{
MOE_SORTING_DISPATCH_(2, true);
}
else
{
MOE_SORTING_DISPATCH_(1, true);
}
}
else
{
if(r_ % 8 == 0)
{
MOE_SORTING_DISPATCH_(8, false);
}
else if(r_ % 4 == 0)
{
MOE_SORTING_DISPATCH_(4, false);
}
else if(r_ % 2 == 0)
{
MOE_SORTING_DISPATCH_(2, false);
}
else
{
MOE_SORTING_DISPATCH_(1, false);
}
}
// MOE_SORTING_DISPATCH_ETILE(0, 0);
#endif
}
return -1;
}

View File

@@ -19,7 +19,8 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
HostTensor<IndexType>& sorted_expert_ids,
index_t& unit_cnt,
const index_t experts,
const index_t unit_size)
const index_t unit_size,
bool skip_experts_with_zero_token = true)
{
const index_t num_token = topk_ids.mDesc.get_lengths()[0];
const index_t topk = topk_ids.mDesc.get_lengths()[1];
@@ -33,8 +34,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
#endif
std::vector<std::vector<WeightType>> expert_token_weights(
experts, std::vector<WeightType>(unit_size, 0));
// count number of unit-size slices in this expert
std::vector<IndexType> expert_slices(experts, 1);
// count the tokens used in this expert
std::vector<IndexType> expert_slice_idxs(experts, 0);
// TODO: above 2 buffer seems duplicated
for(index_t t = 0; t < num_token; t++)
{
@@ -74,6 +78,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
IndexType* out_expert_id = sorted_expert_ids.data();
for(index_t e = 0; e < experts; e++)
{
if(skip_experts_with_zero_token)
{
if(expert_slice_idxs[e] == 0)
continue;
}
memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size);
out_tokens += expert_slices[e] * unit_size;
memcpy(out_weights,

View File

@@ -692,8 +692,6 @@ struct MoeSortingKernel
for(int i_token = 0; i_token < tokens; i_token += sub_tokens)
{
#if 1
// __syncthreads();
// NOTE: below for loop can't have barrier inside!!
for(int i = tid; i < (sub_tokens * topk); i += block_size)
{
@@ -716,7 +714,6 @@ struct MoeSortingKernel
smem_tokens(curr_token_id, eid)++;
}
__builtin_amdgcn_s_waitcnt(0xc07f);
//
}
__syncthreads(); // make sure different i_token iteration not overlap by different wave
// if(tid == 0) {
@@ -740,30 +737,6 @@ struct MoeSortingKernel
// e0+e1+e2+e3+e4+e5+e6+e7
// );
// }
#else
int i = tid;
while(true)
{
__syncthreads();
if(i >= (sub_tokens * topk))
break;
uint32_t curr_token_id, curr_topk_id;
topk_mdiv.divmod(i, curr_token_id, curr_topk_id);
int i_t = i_token + curr_token_id;
// printf("[%d] tid:%d, (i_token:%d, curr_token_id:%d)i_t:%d, curr_topk_id:%d,
// tokens:%d\n",
// i, tid, i_token, curr_token_id, i_t, curr_topk_id, tokens);
if(i_t < tokens)
{
int eid = topk_id[i_t * topk + curr_topk_id];
smem_tokens(curr_token_id, eid)++;
}
i += block_size;
}
__syncthreads();
#endif
}
// counting
@@ -919,8 +892,18 @@ struct MoeSortingKernel
int local_cnt = smem_cumsum(i_e_ + lid + 1);
int blocks_pers_expert =
unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1);
int padded_tokens_per_expert =
max(blocks_pers_expert, 1) * unit_size_mdiv.divisor;
int padded_tokens_per_expert = [&]() {
if constexpr(Problem::SkipExpertsWithZeroTokens)
{
// if local_cnt is zero, blocks_pers_expert will be zero
return blocks_pers_expert * unit_size_mdiv.divisor;
}
else
{
return max(blocks_pers_expert, 1) * unit_size_mdiv.divisor;
}
}();
local_cumsum_ = padded_tokens_per_expert;
local_cumsum_ += pre_cumsum_; // note pre_cumsum must be added after local
// cumsum padded in case local cumsum is zero, but
@@ -952,6 +935,12 @@ struct MoeSortingKernel
int e_end = smem_cumsum(i_e + 1);
// printf("i_e:%d, e_start:%d, e_end:%d\n", i_e, e_start, e_end);
smem_cumdup(i_e) = e_start; // duplicate cumsum for later use
if constexpr(Problem::SkipExpertsWithZeroTokens)
{
if(e_start == e_end) // skip zero token expert
continue;
}
for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
{
p_sorted_expert_ids[unit_size_mdiv.div(i)] = i_e;
@@ -1069,6 +1058,11 @@ struct MoeSortingKernel
int e_start = smem_cumsum(eid);
int e_end = smem_cumdup(eid + 1);
// printf("--- eid:%d, e_start:%d, e_end:%d\n", eid, e_start, e_end);
if constexpr(Problem::SkipExpertsWithZeroTokens)
{
if(e_start == e_end) // skip zero token expert
continue;
}
while(e_start < e_end)
{
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID

View File

@@ -30,17 +30,19 @@ template <typename IndexType_,
typename WeightType_,
index_t SubTokenTile_, // 1,2,4,8, or 0 in the future
bool SubTokenOneShot_, // if we only loop over once or not
index_t ExpertTile_ = 0>
bool SkipExpertsWithZeroTokens_ = true,
index_t ExpertTile_ = 0>
struct MoeSortingProblemEx
{
// TODO: this kernel only support warp per row
using WeightType = remove_cvref_t<WeightType_>;
using IndexType = remove_cvref_t<IndexType_>;
static constexpr index_t WarpSize = get_warp_size();
static constexpr index_t WarpsPerBlock = 1;
static constexpr index_t SubTokenTile = SubTokenTile_;
static constexpr bool SubTokenOneShot = SubTokenOneShot_;
static constexpr index_t WarpSize = get_warp_size();
static constexpr index_t WarpsPerBlock = 1;
static constexpr index_t SubTokenTile = SubTokenTile_;
static constexpr bool SubTokenOneShot = SubTokenOneShot_;
static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_;
static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 || SubTokenTile == 8);
static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
};