mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
support skip empty tokens for expert sorting
This commit is contained in:
@@ -3,6 +3,12 @@
|
||||
|
||||
#include "fused_moesorting.hpp"
|
||||
|
||||
#ifndef MOE_SORTING_USE_EX_KERNEL
|
||||
#define MOE_SORTING_USE_EX_KERNEL 1
|
||||
#endif
|
||||
|
||||
#if !MOE_SORTING_USE_EX_KERNEL
|
||||
|
||||
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr ck_tile::index_t expert_tile = expert_tile_; \
|
||||
@@ -17,6 +23,24 @@
|
||||
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
|
||||
return ave_time;
|
||||
|
||||
#else
|
||||
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_) \
|
||||
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
|
||||
constexpr bool sub_token_onshot = sub_token_onshot_; \
|
||||
using ms_problem = \
|
||||
ck_tile::MoeSortingProblemEx<index_t, ms_weight_type, sub_token_tile, sub_token_onshot>; \
|
||||
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
const auto lds_bytes = kernel::GetSmemSize(a); \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
|
||||
return ave_time;
|
||||
|
||||
#endif
|
||||
|
||||
#if !MOE_SORTING_USE_EX_KERNEL
|
||||
#define MOE_SORTING_DISPATCH(unroll_num_) \
|
||||
if(a.num_experts <= 8) \
|
||||
{ \
|
||||
@@ -38,11 +62,13 @@
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
|
||||
}
|
||||
#endif
|
||||
|
||||
float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s)
|
||||
{
|
||||
if(t.weight_type == "fp32" && t.index_type == "int32")
|
||||
{
|
||||
#if !MOE_SORTING_USE_EX_KERNEL
|
||||
if(a.num_experts > 127)
|
||||
{
|
||||
printf("lds size exceed, only support experts <127 \n");
|
||||
@@ -83,6 +109,54 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
MOE_SORTING_DISPATCH(4);
|
||||
}
|
||||
}
|
||||
#else
|
||||
using index_t = ck_tile::index_t;
|
||||
using ms_weight_type = float;
|
||||
auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts);
|
||||
auto sub_token_ = r_ - 2;
|
||||
r_ = (r_ - 2) / 8;
|
||||
bool is_sub_token_onshot = a.tokens <= sub_token_;
|
||||
(void)c_;
|
||||
if(is_sub_token_onshot)
|
||||
{
|
||||
if(r_ % 8 == 0)
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(8, true);
|
||||
}
|
||||
else if(r_ % 4 == 0)
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(4, true);
|
||||
}
|
||||
else if(r_ % 2 == 0)
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(2, true);
|
||||
}
|
||||
else
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(1, true);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(r_ % 8 == 0)
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(8, false);
|
||||
}
|
||||
else if(r_ % 4 == 0)
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(4, false);
|
||||
}
|
||||
else if(r_ % 2 == 0)
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(2, false);
|
||||
}
|
||||
else
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(1, false);
|
||||
}
|
||||
}
|
||||
// MOE_SORTING_DISPATCH_ETILE(0, 0);
|
||||
#endif
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
@@ -19,7 +19,8 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
|
||||
HostTensor<IndexType>& sorted_expert_ids,
|
||||
index_t& unit_cnt,
|
||||
const index_t experts,
|
||||
const index_t unit_size)
|
||||
const index_t unit_size,
|
||||
bool skip_experts_with_zero_token = true)
|
||||
{
|
||||
const index_t num_token = topk_ids.mDesc.get_lengths()[0];
|
||||
const index_t topk = topk_ids.mDesc.get_lengths()[1];
|
||||
@@ -33,8 +34,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
|
||||
#endif
|
||||
std::vector<std::vector<WeightType>> expert_token_weights(
|
||||
experts, std::vector<WeightType>(unit_size, 0));
|
||||
// count number of unit-size slices in this expert
|
||||
std::vector<IndexType> expert_slices(experts, 1);
|
||||
// count the tokens used in this expert
|
||||
std::vector<IndexType> expert_slice_idxs(experts, 0);
|
||||
// TODO: above 2 buffer seems duplicated
|
||||
|
||||
for(index_t t = 0; t < num_token; t++)
|
||||
{
|
||||
@@ -74,6 +78,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
|
||||
IndexType* out_expert_id = sorted_expert_ids.data();
|
||||
for(index_t e = 0; e < experts; e++)
|
||||
{
|
||||
if(skip_experts_with_zero_token)
|
||||
{
|
||||
if(expert_slice_idxs[e] == 0)
|
||||
continue;
|
||||
}
|
||||
memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size);
|
||||
out_tokens += expert_slices[e] * unit_size;
|
||||
memcpy(out_weights,
|
||||
|
||||
@@ -692,8 +692,6 @@ struct MoeSortingKernel
|
||||
|
||||
for(int i_token = 0; i_token < tokens; i_token += sub_tokens)
|
||||
{
|
||||
#if 1
|
||||
// __syncthreads();
|
||||
// NOTE: below for loop can't have barrier inside!!
|
||||
for(int i = tid; i < (sub_tokens * topk); i += block_size)
|
||||
{
|
||||
@@ -716,7 +714,6 @@ struct MoeSortingKernel
|
||||
smem_tokens(curr_token_id, eid)++;
|
||||
}
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
//
|
||||
}
|
||||
__syncthreads(); // make sure different i_token iteration not overlap by different wave
|
||||
// if(tid == 0) {
|
||||
@@ -740,30 +737,6 @@ struct MoeSortingKernel
|
||||
// e0+e1+e2+e3+e4+e5+e6+e7
|
||||
// );
|
||||
// }
|
||||
|
||||
#else
|
||||
int i = tid;
|
||||
while(true)
|
||||
{
|
||||
__syncthreads();
|
||||
if(i >= (sub_tokens * topk))
|
||||
break;
|
||||
uint32_t curr_token_id, curr_topk_id;
|
||||
topk_mdiv.divmod(i, curr_token_id, curr_topk_id);
|
||||
int i_t = i_token + curr_token_id;
|
||||
// printf("[%d] tid:%d, (i_token:%d, curr_token_id:%d)i_t:%d, curr_topk_id:%d,
|
||||
// tokens:%d\n",
|
||||
// i, tid, i_token, curr_token_id, i_t, curr_topk_id, tokens);
|
||||
if(i_t < tokens)
|
||||
{
|
||||
int eid = topk_id[i_t * topk + curr_topk_id];
|
||||
smem_tokens(curr_token_id, eid)++;
|
||||
}
|
||||
|
||||
i += block_size;
|
||||
}
|
||||
__syncthreads();
|
||||
#endif
|
||||
}
|
||||
|
||||
// counting
|
||||
@@ -919,8 +892,18 @@ struct MoeSortingKernel
|
||||
int local_cnt = smem_cumsum(i_e_ + lid + 1);
|
||||
int blocks_pers_expert =
|
||||
unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1);
|
||||
int padded_tokens_per_expert =
|
||||
max(blocks_pers_expert, 1) * unit_size_mdiv.divisor;
|
||||
int padded_tokens_per_expert = [&]() {
|
||||
if constexpr(Problem::SkipExpertsWithZeroTokens)
|
||||
{
|
||||
// if local_cnt is zero, blocks_pers_expert will be zero
|
||||
return blocks_pers_expert * unit_size_mdiv.divisor;
|
||||
}
|
||||
else
|
||||
{
|
||||
return max(blocks_pers_expert, 1) * unit_size_mdiv.divisor;
|
||||
}
|
||||
}();
|
||||
|
||||
local_cumsum_ = padded_tokens_per_expert;
|
||||
local_cumsum_ += pre_cumsum_; // note pre_cumsum must be added after local
|
||||
// cumsum padded in case local cumsum is zero, but
|
||||
@@ -952,6 +935,12 @@ struct MoeSortingKernel
|
||||
int e_end = smem_cumsum(i_e + 1);
|
||||
// printf("i_e:%d, e_start:%d, e_end:%d\n", i_e, e_start, e_end);
|
||||
smem_cumdup(i_e) = e_start; // duplicate cumsum for later use
|
||||
if constexpr(Problem::SkipExpertsWithZeroTokens)
|
||||
{
|
||||
if(e_start == e_end) // skip zero token expert
|
||||
continue;
|
||||
}
|
||||
|
||||
for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
|
||||
{
|
||||
p_sorted_expert_ids[unit_size_mdiv.div(i)] = i_e;
|
||||
@@ -1069,6 +1058,11 @@ struct MoeSortingKernel
|
||||
int e_start = smem_cumsum(eid);
|
||||
int e_end = smem_cumdup(eid + 1);
|
||||
// printf("--- eid:%d, e_start:%d, e_end:%d\n", eid, e_start, e_end);
|
||||
if constexpr(Problem::SkipExpertsWithZeroTokens)
|
||||
{
|
||||
if(e_start == e_end) // skip zero token expert
|
||||
continue;
|
||||
}
|
||||
while(e_start < e_end)
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
|
||||
@@ -30,17 +30,19 @@ template <typename IndexType_,
|
||||
typename WeightType_,
|
||||
index_t SubTokenTile_, // 1,2,4,8, or 0 in the future
|
||||
bool SubTokenOneShot_, // if we only loop over once or not
|
||||
index_t ExpertTile_ = 0>
|
||||
bool SkipExpertsWithZeroTokens_ = true,
|
||||
index_t ExpertTile_ = 0>
|
||||
struct MoeSortingProblemEx
|
||||
{
|
||||
// TODO: this kernel only support warp per row
|
||||
using WeightType = remove_cvref_t<WeightType_>;
|
||||
using IndexType = remove_cvref_t<IndexType_>;
|
||||
|
||||
static constexpr index_t WarpSize = get_warp_size();
|
||||
static constexpr index_t WarpsPerBlock = 1;
|
||||
static constexpr index_t SubTokenTile = SubTokenTile_;
|
||||
static constexpr bool SubTokenOneShot = SubTokenOneShot_;
|
||||
static constexpr index_t WarpSize = get_warp_size();
|
||||
static constexpr index_t WarpsPerBlock = 1;
|
||||
static constexpr index_t SubTokenTile = SubTokenTile_;
|
||||
static constexpr bool SubTokenOneShot = SubTokenOneShot_;
|
||||
static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_;
|
||||
static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 || SubTokenTile == 8);
|
||||
static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user