mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 03:19:48 +00:00
[CK_TILE] optimize moe sorting kernel, boost large context case up to 20x (#2153)
* combine 2-3 as single stage
* support zeroing
* improve long tokens
* update specialization
* b16 ws
* 8bit topk optimize
* update 15 example
[ROCm/composable_kernel commit: 4e9b76f88c]
This commit is contained in:
@@ -372,7 +372,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
num_sorted_tiles_host.get_element_space_size_in_bytes());
|
||||
|
||||
// if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
|
||||
ck_tile::index_t workspace_size = ck_tile::moe_sorting_get_workspace_size(tokens, experts);
|
||||
ck_tile::index_t workspace_size =
|
||||
ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk);
|
||||
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
|
||||
if(workspace_size != 0)
|
||||
moe_sorting_ws.SetZero(); // note, clear here!!!!
|
||||
|
||||
Reference in New Issue
Block a user