mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
[CK_TILE] add moe-sorting MP kernel (#1910)
* moe sorting ex * fix bug for race condition * fix bug and optimize large expert * fix * optimize with sub_token_oneshot * support skip empty tokens for expert sorting * update moe_sorting * tidy code * support mp kernel * hint mp * remove useless code * porting to example 15 --------- Co-authored-by: valarLip <340077269@qq.com>
This commit is contained in:
@@ -371,6 +371,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem num_sorted_tiles_buf(
|
||||
num_sorted_tiles_host.get_element_space_size_in_bytes());
|
||||
|
||||
// A return value of zero means no workspace is needed, so moe_sorting_args.p_ws can be set to nullptr.
|
||||
ck_tile::index_t workspace_size = ck_tile::moe_sorting_get_workspace_size(tokens, experts);
|
||||
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
|
||||
if(workspace_size != 0)
|
||||
moe_sorting_ws.SetZero(); // note, clear here!!!!
|
||||
|
||||
fused_moe_traits traits{prec_i,
|
||||
prec_w,
|
||||
prec_o,
|
||||
@@ -394,6 +400,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
local_expert_masking ? local_expert_mask_buf.GetDeviceBuffer()
|
||||
: nullptr,
|
||||
o_buf.GetDeviceBuffer(),
|
||||
workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr,
|
||||
topk_ids_buf.GetDeviceBuffer(),
|
||||
topk_weight_buf.GetDeviceBuffer(),
|
||||
sorted_token_ids_buf.GetDeviceBuffer(),
|
||||
|
||||
Reference in New Issue
Block a user