mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
[CK_TILE] moe sorting optimize local_token (#2469)
* fix bug in loops that need use local tokens to compute * support extra chain local_token * update * update * refine some main * update * support dispatch_policy * fix 15 example
This commit is contained in:
@@ -6,7 +6,8 @@
|
||||
|
||||
int fused_moe_get_workspace_size(int tokens, int num_experts, int topk)
|
||||
{
|
||||
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk);
|
||||
return ck_tile::moe_sorting_get_workspace_size(
|
||||
tokens, num_experts, topk, 0 /*dispatch policy*/);
|
||||
}
|
||||
|
||||
float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_config& s)
|
||||
@@ -24,23 +25,28 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
|
||||
}();
|
||||
|
||||
auto t0 = fused_moesorting_trait{"int32", "fp32", t.local_expert_masking};
|
||||
auto a0 = fused_moesorting_args{
|
||||
a.topk_ids_ptr, // const void* p_topk_ids;
|
||||
a.topk_weight_ptr, // const void* p_weights;
|
||||
a.local_expert_mask_ptr, // const void* p_local_expert_mask;
|
||||
a.local_tokens,
|
||||
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
|
||||
a.sorted_weight_ptr, // void* p_sorted_weights;
|
||||
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
|
||||
a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad;
|
||||
a.o_ptr, // void* p_moe_buf;
|
||||
a.ws_ptr, // void* p_ws;
|
||||
a.num_tokens, // index_t tokens;
|
||||
a.block_m, // index_t unit_size;
|
||||
a.num_experts, // index_t num_experts;
|
||||
a.topk, // index_t topk;
|
||||
static_cast<ck_tile::long_index_t>(a.num_tokens) * a.stride_token *
|
||||
o_data_bytes // index_t moe_buf_bytes;
|
||||
auto a0 = fused_moesorting_args
|
||||
{
|
||||
a.topk_ids_ptr, // const void* p_topk_ids;
|
||||
a.topk_weight_ptr, // const void* p_weights;
|
||||
a.local_expert_mask_ptr, // const void* p_local_expert_mask;
|
||||
a.local_tokens,
|
||||
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
|
||||
a.sorted_weight_ptr, // void* p_sorted_weights;
|
||||
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
|
||||
a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad;
|
||||
a.o_ptr, // void* p_moe_buf;
|
||||
a.ws_ptr, // void* p_ws;
|
||||
a.num_tokens, // index_t tokens;
|
||||
a.block_m, // index_t unit_size;
|
||||
a.num_experts, // index_t num_experts;
|
||||
a.topk, // index_t topk;
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
a.stride_token, o_data_bytes,
|
||||
#else
|
||||
static_cast<ck_tile::long_index_t>(a.num_tokens) *
|
||||
a.stride_token* o_data_bytes // index_t moe_buf_bytes;
|
||||
#endif
|
||||
};
|
||||
|
||||
auto t1 = fused_moegemm_traits{t.prec_i,
|
||||
|
||||
@@ -413,5 +413,6 @@ float fused_moesorting_mp(fused_moesorting_trait t,
|
||||
|
||||
int fused_moesorting_get_workspace_size(int tokens, int num_experts, int topk)
|
||||
{
|
||||
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk);
|
||||
return ck_tile::moe_sorting_get_workspace_size(
|
||||
tokens, num_experts, topk, 0 /*dispatch policy*/);
|
||||
}
|
||||
|
||||
@@ -399,7 +399,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
// if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
|
||||
ck_tile::index_t workspace_size =
|
||||
ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk);
|
||||
ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk, 0 /*dispatch_policy*/);
|
||||
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
|
||||
if(workspace_size != 0)
|
||||
moe_sorting_ws.SetZero(); // note, clear here!!!!
|
||||
|
||||
Reference in New Issue
Block a user