From 10a288c3a296c3ba7719edad0709f811bcabbc37 Mon Sep 17 00:00:00 2001 From: lalala-sh Date: Thu, 18 Sep 2025 08:57:20 +0800 Subject: [PATCH] opt moe sorting (#2822) * opt moe storing for 2k * rm duplicated clear --------- Co-authored-by: root --- example/ck_tile/13_moe_sorting/CMakeLists.txt | 2 +- .../13_moe_sorting/moe_sorting_api.cpp | 239 ++++++++++++---- .../instances/fused_moesorting_api.cpp | 257 ++++++++++++++---- .../fused_moe/kernel/moe_sorting_kernel.hpp | 130 ++++++++- 4 files changed, 511 insertions(+), 117 deletions(-) diff --git a/example/ck_tile/13_moe_sorting/CMakeLists.txt b/example/ck_tile/13_moe_sorting/CMakeLists.txt index 09f3e4ac4e..1e89b730d7 100644 --- a/example/ck_tile/13_moe_sorting/CMakeLists.txt +++ b/example/ck_tile/13_moe_sorting/CMakeLists.txt @@ -3,6 +3,6 @@ target_include_directories(tile_example_moe_sorting PRIVATE ${CMAKE_CURRENT_SOUR set(EXAMPLE_MOE_SORTING_COMPILE_OPTIONS) # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations -list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal -Wno-error) # list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) target_compile_options(tile_example_moe_sorting PRIVATE ${EXAMPLE_MOE_SORTING_COMPILE_OPTIONS}) diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp index efb01af009..d69adbcb92 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp @@ -194,7 +194,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi return -1; } -#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_) \ +#define MOE_SORTING_MP_0_V1(mesh_type_, unroll_num_, expert_masking_, local_token_) \ [&]() { \ constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr bool expert_masking = expert_masking_; \ @@ -205,7 +205,25 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi unroll_num, \ expert_masking, \ local_token>; \ - using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v1; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() + +#define MOE_SORTING_MP_0_V2(mesh_type_, unroll_num_, expert_masking_, local_token_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + constexpr bool local_token = local_token_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v2; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ const dim3 blocks = kernel::BlockSize(a); \ @@ -286,6 +304,46 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \ }() +#define MOR_SORTING_MP_DISPATCH_SMALL_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ + if(t.local_expert_masking) \ + { \ + if(is_local_token) \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, true), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \ + return ave_time; \ + } \ + else \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, false), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \ + return ave_time; \ + } \ + } \ + else \ + { \ + if(is_local_token) \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, true), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \ + return ave_time; \ + } \ + else \ + { \ + float ave_time = ck_tile::launch_kernel( \ + s, \ + MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, false), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \ + return ave_time; \ + } \ + } + #define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ if(t.local_expert_masking) \ { \ @@ -294,7 +352,8 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi float ave_time = \ ck_tile::launch_kernel(s, \ maybe_clear_workspace, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \ + MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, true), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \ return ave_time; \ } \ @@ -303,7 +362,8 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi float ave_time = \ ck_tile::launch_kernel(s, \ maybe_clear_workspace, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \ + MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, false), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \ return ave_time; \ } \ @@ -315,7 +375,8 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi float ave_time = \ ck_tile::launch_kernel(s, \ maybe_clear_workspace, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \ + MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, true), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \ return ave_time; \ } \ @@ -324,7 +385,8 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi float ave_time = ck_tile::launch_kernel( \ s, \ maybe_clear_workspace, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \ + MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, false), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \ return ave_time; \ } \ @@ -365,67 +427,136 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co } }; - if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > - ck_tile::get_smem_capacity()) + if (a.tokens < 2048) { + if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > + ck_tile::get_smem_capacity()) + { #if MOE_SORTING_SUPPORT_LARGE_EXPERT - if(t.local_expert_masking) - { - float ave_time = ck_tile::launch_kernel(s, - maybe_clear_workspace, - MOE_SORTING_MP_0(ms_index_t, 1, true), - MOE_SORTING_MP_2(ms_index_t, 1, true), - MOE_SORTING_MP_3(ms_index_t, 1, true)); - return ave_time; - } - else - { - float ave_time = ck_tile::launch_kernel(s, - maybe_clear_workspace, - MOE_SORTING_MP_0(ms_index_t, 1, false), - MOE_SORTING_MP_2(ms_index_t, 1, false), - MOE_SORTING_MP_3(ms_index_t, 1, false)); - return ave_time; - } -#else - printf("do not support large expert %d\n", a.num_experts); - return -1; -#endif - } - else - { - ck_tile::index_t mesh_byte_size = - ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk); - if(mesh_byte_size == 1) - { - if(a.tokens * a.topk % 4 == 0) + if(t.local_expert_masking) { - MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16) + float ave_time = ck_tile::launch_kernel(s, + MOE_SORTING_MP_0_V2(ms_index_t, 1, true), + MOE_SORTING_MP_2(ms_index_t, 1, true), + MOE_SORTING_MP_3(ms_index_t, 1, true)); + return ave_time; } else { - MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16) - } - } - else if(mesh_byte_size == 2) - { -#if MOE_SORTING_SUPPORT_LARGE_TOPK - if(a.tokens * a.topk % 4 == 0) - { - MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8) - } - else - { - MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8) + float ave_time = ck_tile::launch_kernel(s, + MOE_SORTING_MP_0_V2(ms_index_t, 1, false), + MOE_SORTING_MP_2(ms_index_t, 1, false), + MOE_SORTING_MP_3(ms_index_t, 1, false)); + return ave_time; } #else - printf("do not support large topk %d\n", a.topk); + printf("do not support large expert %d\n", a.num_experts); return -1; #endif } else { - MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1) + ck_tile::index_t mesh_byte_size = + ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk); + if(mesh_byte_size == 1) + { + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 4, 16, 16) + } + else + { + MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 1, 16, 16) + } + } + else if(mesh_byte_size == 2) + { +#if MOE_SORTING_SUPPORT_LARGE_TOPK + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 4, 8, 8) + } + else + { + MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 1, 8, 8) + } +#else + printf("do not support large topk %d\n", a.topk); + return -1; +#endif + } + else + { + MOR_SORTING_MP_DISPATCH_SMALL_(ck_tile::index_t, 1, 1, 1) + } + } + } + else + { + if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > + ck_tile::get_smem_capacity()) + { +#if MOE_SORTING_SUPPORT_LARGE_EXPERT + if(t.local_expert_masking) + { + float ave_time = ck_tile::launch_kernel(s, + maybe_clear_workspace, + MOE_SORTING_MP_0_V1(ms_index_t, 1, true), + MOE_SORTING_MP_1(ms_index_t, 1, true), + MOE_SORTING_MP_2(ms_index_t, 1, true), + MOE_SORTING_MP_3(ms_index_t, 1, true)); + return ave_time; + } + else + { + float ave_time = ck_tile::launch_kernel(s, + maybe_clear_workspace, + MOE_SORTING_MP_0_V1(ms_index_t, 1, false), + MOE_SORTING_MP_1(ms_index_t, 1, false), + MOE_SORTING_MP_2(ms_index_t, 1, false), + MOE_SORTING_MP_3(ms_index_t, 1, false)); + return ave_time; + } +#else + printf("do not support large expert %d\n", a.num_experts); + return -1; +#endif + } + else + { + ck_tile::index_t mesh_byte_size = + ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk); + if(mesh_byte_size == 1) + { + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16) + } + else + { + MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16) + } + } + else if(mesh_byte_size == 2) + { +#if MOE_SORTING_SUPPORT_LARGE_TOPK + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8) + } + else + { + MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8) + } +#else + printf("do not support large topk %d\n", a.topk); + return -1; +#endif + } + else + { + MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1) + } } } } diff --git a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp index 3b434bd538..5b2bacf7e5 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp @@ -198,7 +198,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til return -1; } -#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_) \ +#define MOE_SORTING_MP_0_V1(mesh_type_, unroll_num_, expert_masking_, local_token_) \ [&]() { \ constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr bool expert_masking = expert_masking_; \ @@ -209,7 +209,25 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til unroll_num, \ expert_masking, \ local_token>; \ - using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v1; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() + +#define MOE_SORTING_MP_0_V2(mesh_type_, unroll_num_, expert_masking_, local_token_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + constexpr bool local_token = local_token_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v2; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ const dim3 blocks = kernel::BlockSize(a); \ @@ -290,14 +308,14 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \ }() -#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ +#define MOR_SORTING_MP_DISPATCH_SMALL_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ if(t.local_expert_masking) \ { \ if(is_local_token) \ { \ float ave_time = \ ck_tile::launch_kernel(s, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \ + MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, true), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \ return ave_time; \ } \ @@ -305,7 +323,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til { \ float ave_time = \ ck_tile::launch_kernel(s, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \ + MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, false), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \ return ave_time; \ } \ @@ -316,7 +334,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til { \ float ave_time = \ ck_tile::launch_kernel(s, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \ + MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, true), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \ return ave_time; \ } \ @@ -324,7 +342,51 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til { \ float ave_time = ck_tile::launch_kernel( \ s, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \ + MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, false), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \ + return ave_time; \ + } \ + } + +#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ + if(t.local_expert_masking) \ + { \ + if(is_local_token) \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, true), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \ + return ave_time; \ + } \ + else \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, false), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \ + return ave_time; \ + } \ + } \ + else \ + { \ + if(is_local_token) \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, true), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \ + return ave_time; \ + } \ + else \ + { \ + float ave_time = ck_tile::launch_kernel( \ + s, \ + MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, false), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \ return ave_time; \ } \ @@ -340,65 +402,154 @@ float fused_moesorting_mp(fused_moesorting_trait t, using ms_index_t = ck_tile::index_t; using ms_weight_type = float; - if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > - ck_tile::get_smem_capacity()) + auto maybe_clear_workspace = [=](const ck_tile::stream_config& s_) { + if(t.clear_workspace_inside_api) + { + if(is_local_token) + { + auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(true, 1024, 1); + k(s_); + } + else + { + auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(false, 1024, 1); + k(s_); + } + } + }; + + if (a.tokens < 2048) { + if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > + ck_tile::get_smem_capacity()) + { #if MOE_SORTING_SUPPORT_LARGE_EXPERT - if(t.local_expert_masking) - { - float ave_time = ck_tile::launch_kernel(s, - MOE_SORTING_MP_0(ms_index_t, 1, true), - MOE_SORTING_MP_2(ms_index_t, 1, true), - MOE_SORTING_MP_3(ms_index_t, 1, true)); - return ave_time; - } - else - { - float ave_time = ck_tile::launch_kernel(s, - MOE_SORTING_MP_0(ms_index_t, 1, false), - MOE_SORTING_MP_2(ms_index_t, 1, false), - MOE_SORTING_MP_3(ms_index_t, 1, false)); - return ave_time; - } -#else - printf("do not support large expert %d\n", a.num_experts); - return -1; -#endif - } - else - { - ck_tile::index_t mesh_byte_size = - ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk); - if(mesh_byte_size == 1) - { - if(a.tokens * a.topk % 4 == 0) + if(t.local_expert_masking) { - MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16) + float ave_time = ck_tile::launch_kernel(s, + maybe_clear_workspace, + MOE_SORTING_MP_0_V2(ms_index_t, 1, true), + MOE_SORTING_MP_2(ms_index_t, 1, true), + MOE_SORTING_MP_3(ms_index_t, 1, true)); + return ave_time; } else { - MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16) - } - } - else if(mesh_byte_size == 2) - { -#if MOE_SORTING_SUPPORT_LARGE_TOPK - if(a.tokens * a.topk % 4 == 0) - { - MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8) - } - else - { - MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8) + float ave_time = ck_tile::launch_kernel(s, + maybe_clear_workspace, + MOE_SORTING_MP_0_V2(ms_index_t, 1, false), + MOE_SORTING_MP_2(ms_index_t, 1, false), + MOE_SORTING_MP_3(ms_index_t, 1, false)); + return ave_time; } #else - printf("do not support large topk %d\n", a.topk); + printf("do not support large expert %d\n", a.num_experts); return -1; #endif } else { - MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1) + ck_tile::index_t mesh_byte_size = + ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk); + if(mesh_byte_size == 1) + { + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 4, 16, 16) + } + else + { + MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 1, 16, 16) + } + } + else if(mesh_byte_size == 2) + { +#if MOE_SORTING_SUPPORT_LARGE_TOPK + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 4, 8, 8) + } + else + { + MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 1, 8, 8) + } +#else + printf("do not support large topk %d\n", a.topk); + return -1; +#endif + } + else + { + MOR_SORTING_MP_DISPATCH_SMALL_(ck_tile::index_t, 1, 1, 1) + } + } + } + else + { + if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > + ck_tile::get_smem_capacity()) + { +#if MOE_SORTING_SUPPORT_LARGE_EXPERT + if(t.local_expert_masking) + { + float ave_time = ck_tile::launch_kernel(s, + maybe_clear_workspace, + MOE_SORTING_MP_0_V1(ms_index_t, 1, true), + MOE_SORTING_MP_1(ms_index_t, 1, true), + MOE_SORTING_MP_2(ms_index_t, 1, true), + MOE_SORTING_MP_3(ms_index_t, 1, true)); + return ave_time; + } + else + { + float ave_time = ck_tile::launch_kernel(s, + maybe_clear_workspace, + MOE_SORTING_MP_0_V1(ms_index_t, 1, false), + MOE_SORTING_MP_1(ms_index_t, 1, false), + MOE_SORTING_MP_2(ms_index_t, 1, false), + MOE_SORTING_MP_3(ms_index_t, 1, false)); + return ave_time; + } +#else + printf("do not support large expert %d\n", a.num_experts); + return -1; +#endif + } + else + { + ck_tile::index_t mesh_byte_size = + ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk); + if(mesh_byte_size == 1) + { + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16) + } + else + { + MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16) + } + } + else if(mesh_byte_size == 2) + { +#if MOE_SORTING_SUPPORT_LARGE_TOPK + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8) + } + else + { + MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8) + } +#else + printf("do not support large topk %d\n", a.topk); + return -1; +#endif + } + else + { + MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1) + } } } } diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 00853714c6..976ecf7161 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -1373,7 +1373,7 @@ CK_TILE_HOST index_t moe_sorting_mp_get_workspace_size(int tokens_, int num_expe { index_t s_ = impl::moe_sorting_mp_mesh_smem_size(tokens_, num_experts_, topk_) + impl::moe_sorting_mp_cumsum_smem_size(num_experts_) -#if 1 +#if MOE_SORTING_FUSE_MP_01 + impl::moe_sorting_mp_sem_smem_size(); #else ; @@ -1552,7 +1552,123 @@ p_m_cumsum // count topk_id into mesh template -struct MoeSortingMultiPhaseKernel_P0 +struct MoeSortingMultiPhaseKernel_P0_v1 +{ + using Problem = remove_cvref_t; + + using IndexType = typename Problem::IndexType; + using WeightType = typename Problem::WeightType; + using MeshType = typename Problem::MeshType; + + static constexpr index_t BLOCK_SIZE = 256; + static constexpr index_t OCCUPANCY = 2; // hard coded + + typedef MoeSortingHostArgs MoeSortingKargs; + + using Hargs = MoeSortingHostArgs; + + struct Kargs + { + const void* p_topk_ids; // [tokens, topk] + const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens + void* p_expert_mesh; // [expert, tokens] + index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens + // used for ws/LDS calculation + index_t mesh_stride; // mesh_stride for p_expert_mesh + mdiv topk_mdiv; + }; + + CK_TILE_HOST static constexpr auto get_num_cu() + { + index_t num_cu = [&]() { + hipDeviceProp_t dev_prop; + hipDevice_t dev; + HIP_CHECK_ERROR(hipGetDevice(&dev)); + HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev)); + return dev_prop.multiProcessorCount; + }(); + return num_cu; + } + + CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) + { + Kargs k; + k.p_topk_ids = h.p_topk_ids; + k.p_local_tokens = h.p_local_tokens; + k.p_expert_mesh = h.p_ws; + k.tokens = h.tokens; + k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens); + k.topk_mdiv = mdiv{static_cast(h.topk)}; + return k; + } + + CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; } + + CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } + + // in byte + CK_TILE_HOST static constexpr auto GetSmemSize() { return 0; } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + using topk_id_t = ext_vector_t; + + const topk_id_t* p_topk_ids = reinterpret_cast(kargs.p_topk_ids); + MeshType* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh); + index_t tokens = [&]() { + if constexpr(Problem::LocalToken) + { + return reinterpret_cast(kargs.p_local_tokens)[0]; + } + else + { + return kargs.tokens; + } + }(); + index_t rounded_tokens = [&]() { + if constexpr(Problem::LocalToken) + { + return (tokens + Problem::SubTokenTile - 1) / Problem::SubTokenTile * + Problem::SubTokenTile; + } + else + return tokens; + }(); + index_t mesh_stride = [&]() { + if constexpr(Problem::LocalToken) + { + return impl::moe_sorting_mp_mesh_stride(tokens); + } + else + { + return kargs.mesh_stride; + } + }(); + index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile; + +#pragma unroll Problem::SubTokenTile + for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem; + i += gridDim.x * BLOCK_SIZE) + { + auto x = p_topk_ids[i]; + static_for<0, Problem::SubTokenTile, 1>{}([&](auto j) { + IndexType eid = x[j.value]; // ext_vector_type must use int to [] + uint32_t curr_token_id, curr_topk_id; + kargs.topk_mdiv.divmod(i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id); + if constexpr(Problem::LocalToken) + { + if(static_cast(curr_token_id) < tokens) + p_expert_mesh[eid * mesh_stride + curr_token_id] = + (curr_topk_id + 1) & 0xffff; + } + else + p_expert_mesh[eid * mesh_stride + curr_token_id] = (curr_topk_id + 1) & 0xffff; + }); + } + } +}; +template +struct MoeSortingMultiPhaseKernel_P0_v2 { using Problem = remove_cvref_t; @@ -1577,7 +1693,7 @@ struct MoeSortingMultiPhaseKernel_P0 mdiv topk_mdiv; const void* p_local_expert_mask; // [expert] - void* p_expert_cumsum; + void* p_expert_cumsum; // [expert] index_t num_experts; }; @@ -1669,12 +1785,6 @@ struct MoeSortingMultiPhaseKernel_P0 mask = p_local_expert_mask[eid]; } MeshType* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh) + eid * mesh_stride; - // const index_t total_bytes = mesh_stride * 4; - - // using vector_type = ext_vector_t; - // auto zero_ = vector_type{0}; - - // vector_type* p_expert_mesh_clear = reinterpret_cast(kargs.p_expert_mesh); for(index_t i = threadIdx.x; i < mesh_stride; i += BLOCK_SIZE) { p_expert_mesh[i] = 0; @@ -1750,6 +1860,8 @@ struct MoeSortingMultiPhaseKernel_P0 } }; + + // cnt total tokens for a expert template struct MoeSortingMultiPhaseKernel_P1