[CK_TILE] add moe-sorting MP kernel (#1910)

* moe sorting ex

* fix bug for race condition

* fix bug and optimze large expert

* fix

* optimize with sub_token_oneshot

* support skip empty tokens for expert sorting

* update moe_sorting

* tidy code

* support mp kernel

* hint mp

* remove use less code

* porting to example 15

---------

Co-authored-by: valarLip <340077269@qq.com>
This commit is contained in:
carlushuang
2025-02-25 17:56:55 +08:00
committed by GitHub
parent 020148d0f7
commit 353a612b44
8 changed files with 1043 additions and 31 deletions

View File

@@ -371,6 +371,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem num_sorted_tiles_buf(
num_sorted_tiles_host.get_element_space_size_in_bytes());
// if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
ck_tile::index_t workspace_size = ck_tile::moe_sorting_get_workspace_size(tokens, experts);
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
if(workspace_size != 0)
moe_sorting_ws.SetZero(); // note, clear here!!!!
fused_moe_traits traits{prec_i,
prec_w,
prec_o,
@@ -394,6 +400,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
local_expert_masking ? local_expert_mask_buf.GetDeviceBuffer()
: nullptr,
o_buf.GetDeviceBuffer(),
workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr,
topk_ids_buf.GetDeviceBuffer(),
topk_weight_buf.GetDeviceBuffer(),
sorted_token_ids_buf.GetDeviceBuffer(),