mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
[CK_TILE] moe sorting ex kernel to support expert > 128 (#1840)
* moe sorting ex * fix bug for race condition * fix bug and optimize large expert * fix * optimize with sub_token_oneshot * support skip empty tokens for expert sorting * update moe_sorting * tidy code
This commit is contained in:
@@ -14,12 +14,15 @@ namespace ck_tile {
|
||||
template <typename WeightType, typename IndexType = index_t>
|
||||
CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
|
||||
const HostTensor<WeightType>& weights,
|
||||
const HostTensor<IndexType>& local_expert_mask,
|
||||
HostTensor<IndexType>& p_sorted_token_ids,
|
||||
HostTensor<WeightType>& sorted_weight,
|
||||
HostTensor<IndexType>& sorted_expert_ids,
|
||||
index_t& unit_cnt,
|
||||
const index_t experts,
|
||||
const index_t unit_size)
|
||||
const index_t unit_size,
|
||||
bool local_expert_masking,
|
||||
bool skip_experts_with_zero_token = true)
|
||||
{
|
||||
const index_t num_token = topk_ids.mDesc.get_lengths()[0];
|
||||
const index_t topk = topk_ids.mDesc.get_lengths()[1];
|
||||
@@ -33,8 +36,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
|
||||
#endif
|
||||
std::vector<std::vector<WeightType>> expert_token_weights(
|
||||
experts, std::vector<WeightType>(unit_size, 0));
|
||||
// count number of unit-size slices in this expert
|
||||
std::vector<IndexType> expert_slices(experts, 1);
|
||||
// count the tokens used in this expert
|
||||
std::vector<IndexType> expert_slice_idxs(experts, 0);
|
||||
// TODO: above 2 buffer seems duplicated
|
||||
|
||||
for(index_t t = 0; t < num_token; t++)
|
||||
{
|
||||
@@ -72,8 +78,23 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
|
||||
IndexType* out_tokens = p_sorted_token_ids.data();
|
||||
WeightType* out_weights = sorted_weight.data();
|
||||
IndexType* out_expert_id = sorted_expert_ids.data();
|
||||
int curr_expert_id = 0;
|
||||
for(index_t e = 0; e < experts; e++)
|
||||
{
|
||||
if(local_expert_masking)
|
||||
{
|
||||
if(local_expert_mask(e) == 0)
|
||||
continue;
|
||||
}
|
||||
if(skip_experts_with_zero_token)
|
||||
{
|
||||
if(expert_slice_idxs[e] == 0)
|
||||
{
|
||||
curr_expert_id++;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size);
|
||||
out_tokens += expert_slices[e] * unit_size;
|
||||
memcpy(out_weights,
|
||||
@@ -83,10 +104,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
|
||||
|
||||
for(index_t s = 0; s < expert_slices[e]; s++)
|
||||
{
|
||||
out_expert_id[s] = e;
|
||||
out_expert_id[s] = curr_expert_id;
|
||||
unit_cnt++;
|
||||
}
|
||||
out_expert_id += expert_slices[e];
|
||||
curr_expert_id++;
|
||||
}
|
||||
unit_cnt *= unit_size;
|
||||
return;
|
||||
|
||||
Reference in New Issue
Block a user