[CK_TILE] optimize moe sorting kernel, boost large context case up to 20x (#2153)

* combine 2-3 as single stage

* support zeroing

* improve long tokens

* update specialization

* b16 ws

* 8bit topk optimize

* update 15 example
This commit is contained in:
carlushuang
2025-05-06 17:32:07 +08:00
committed by GitHub
parent b8fa27bfef
commit 4e9b76f88c
15 changed files with 1216 additions and 115 deletions

View File

@@ -372,7 +372,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
num_sorted_tiles_host.get_element_space_size_in_bytes());
// if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
ck_tile::index_t workspace_size = ck_tile::moe_sorting_get_workspace_size(tokens, experts);
ck_tile::index_t workspace_size =
ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk);
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
if(workspace_size != 0)
moe_sorting_ws.SetZero(); // note, clear here!!!!