diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp index d614b8462a..00c6be8f10 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp @@ -194,22 +194,40 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi return -1; } -#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_) \ - [&]() { \ - constexpr ck_tile::index_t unroll_num = unroll_num_; \ - constexpr bool expert_masking = expert_masking_; \ - constexpr bool local_token = local_token_; \ - using ms_problem = ck_tile::MoeSortingProblemMp; \ - using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0; \ - auto kargs = kernel::MakeKargs(a); \ - const dim3 grids = kernel::GridSize(a); \ - const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ +#define MOE_SORTING_MP_0_V1(mesh_type_, unroll_num_, expert_masking_, local_token_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + constexpr bool local_token = local_token_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v1; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() + +#define MOE_SORTING_MP_0_V2(mesh_type_, unroll_num_, expert_masking_, local_token_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + constexpr bool local_token = local_token_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v2; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() #define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \ @@ -286,6 +304,46 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \ }() +#define MOR_SORTING_MP_DISPATCH_SMALL_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ + if(t.local_expert_masking) \ + { \ + if(is_local_token) \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, true), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \ + return ave_time; \ + } \ + else \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, false), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \ + return ave_time; \ + } \ + } \ + else \ + { \ + if(is_local_token) \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, true), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \ + return ave_time; \ + } \ + else \ + { \ + float ave_time = ck_tile::launch_kernel( \ + s, \ + MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, false), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \ + return ave_time; \ + } \ + } + #define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ if(t.local_expert_masking) \ { \ @@ -294,7 +352,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi float ave_time = \ ck_tile::launch_kernel(s, \ maybe_clear_workspace, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \ + MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, true), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \ return ave_time; \ @@ -304,7 +362,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi float ave_time = \ ck_tile::launch_kernel(s, \ maybe_clear_workspace, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \ + MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, false), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \ return ave_time; \ @@ -317,7 +375,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi float ave_time = \ ck_tile::launch_kernel(s, \ maybe_clear_workspace, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \ + MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, true), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \ return ave_time; \ @@ -327,7 +385,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi float ave_time = ck_tile::launch_kernel( \ s, \ maybe_clear_workspace, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \ + MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, false), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \ return ave_time; \ @@ -369,69 +427,140 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co } }; - if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > - ck_tile::get_smem_capacity()) + if(a.tokens < 2048) { + if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > + ck_tile::get_smem_capacity()) + { #if MOE_SORTING_SUPPORT_LARGE_EXPERT - if(t.local_expert_masking) - { - float ave_time = ck_tile::launch_kernel(s, - maybe_clear_workspace, - MOE_SORTING_MP_0(ms_index_t, 1, true), - MOE_SORTING_MP_1(ms_index_t, 1, true), - MOE_SORTING_MP_2(ms_index_t, 1, true), - MOE_SORTING_MP_3(ms_index_t, 1, true)); - return ave_time; - } - else - { - float ave_time = ck_tile::launch_kernel(s, - maybe_clear_workspace, - MOE_SORTING_MP_0(ms_index_t, 1, false), - MOE_SORTING_MP_1(ms_index_t, 1, false), - MOE_SORTING_MP_2(ms_index_t, 1, false), - MOE_SORTING_MP_3(ms_index_t, 1, false)); - return ave_time; - } -#else - printf("do not support large expert %d\n", a.num_experts); - return -1; -#endif - } - else - { - ck_tile::index_t mesh_byte_size = - ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk); - if(mesh_byte_size == 1) - { - if(a.tokens * a.topk % 4 == 0) + if(t.local_expert_masking) { - MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16) + float ave_time = + ck_tile::launch_kernel(s, + MOE_SORTING_MP_0_V2(ms_index_t, 1, true), + MOE_SORTING_MP_2(ms_index_t, 1, true), + MOE_SORTING_MP_3(ms_index_t, 1, true)); + return ave_time; } else { - MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16) - } - } - else if(mesh_byte_size == 2) - { -#if MOE_SORTING_SUPPORT_LARGE_TOPK - if(a.tokens * a.topk % 4 == 0) - { - MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8) - } - else - { - MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8) + float ave_time = + ck_tile::launch_kernel(s, + MOE_SORTING_MP_0_V2(ms_index_t, 1, false), + MOE_SORTING_MP_2(ms_index_t, 1, false), + MOE_SORTING_MP_3(ms_index_t, 1, false)); + return ave_time; } #else - printf("do not support large topk %d\n", a.topk); + printf("do not support large expert %d\n", a.num_experts); return -1; #endif } else { - MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1) + ck_tile::index_t mesh_byte_size = + ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk); + if(mesh_byte_size == 1) + { + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 4, 16, 16) + } + else + { + MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 1, 16, 16) + } + } + else if(mesh_byte_size == 2) + { +#if MOE_SORTING_SUPPORT_LARGE_TOPK + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 4, 8, 8) + } + else + { + MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 1, 8, 8) + } +#else + printf("do not support large topk %d\n", a.topk); + return -1; +#endif + } + else + { + MOR_SORTING_MP_DISPATCH_SMALL_(ck_tile::index_t, 1, 1, 1) + } + } + } + else + { + if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > + ck_tile::get_smem_capacity()) + { +#if MOE_SORTING_SUPPORT_LARGE_EXPERT + if(t.local_expert_masking) + { + float ave_time = + ck_tile::launch_kernel(s, + maybe_clear_workspace, + MOE_SORTING_MP_0_V1(ms_index_t, 1, true), + MOE_SORTING_MP_1(ms_index_t, 1, true), + MOE_SORTING_MP_2(ms_index_t, 1, true), + MOE_SORTING_MP_3(ms_index_t, 1, true)); + return ave_time; + } + else + { + float ave_time = + ck_tile::launch_kernel(s, + maybe_clear_workspace, + MOE_SORTING_MP_0_V1(ms_index_t, 1, false), + MOE_SORTING_MP_1(ms_index_t, 1, false), + MOE_SORTING_MP_2(ms_index_t, 1, false), + MOE_SORTING_MP_3(ms_index_t, 1, false)); + return ave_time; + } +#else + printf("do not support large expert %d\n", a.num_experts); + return -1; +#endif + } + else + { + ck_tile::index_t mesh_byte_size = + ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk); + if(mesh_byte_size == 1) + { + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16) + } + else + { + MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16) + } + } + else if(mesh_byte_size == 2) + { +#if MOE_SORTING_SUPPORT_LARGE_TOPK + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8) + } + else + { + MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8) + } +#else + printf("do not support large topk %d\n", a.topk); + return -1; +#endif + } + else + { + MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1) + } } } } diff --git a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp index 441aa84edf..5edb74f52f 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp @@ -198,22 +198,40 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til return -1; } -#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_) \ - [&]() { \ - constexpr ck_tile::index_t unroll_num = unroll_num_; \ - constexpr bool expert_masking = expert_masking_; \ - constexpr bool local_token = local_token_; \ - using ms_problem = ck_tile::MoeSortingProblemMp; \ - using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0; \ - auto kargs = kernel::MakeKargs(a); \ - const dim3 grids = kernel::GridSize(a); \ - const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ +#define MOE_SORTING_MP_0_V1(mesh_type_, unroll_num_, expert_masking_, local_token_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + constexpr bool local_token = local_token_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v1; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() + +#define MOE_SORTING_MP_0_V2(mesh_type_, unroll_num_, expert_masking_, local_token_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + constexpr bool local_token = local_token_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v2; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() #define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \ @@ -290,6 +308,46 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \ }() +#define MOR_SORTING_MP_DISPATCH_SMALL_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ + if(t.local_expert_masking) \ + { \ + if(is_local_token) \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, true), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \ + return ave_time; \ + } \ + else \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, false), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \ + return ave_time; \ + } \ + } \ + else \ + { \ + if(is_local_token) \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, true), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \ + return ave_time; \ + } \ + else \ + { \ + float ave_time = ck_tile::launch_kernel( \ + s, \ + MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, false), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \ + return ave_time; \ + } \ + } + #define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ if(t.local_expert_masking) \ { \ @@ -297,7 +355,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til { \ float ave_time = \ ck_tile::launch_kernel(s, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \ + MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, true), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \ return ave_time; \ @@ -306,7 +364,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til { \ float ave_time = \ ck_tile::launch_kernel(s, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \ + MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, false), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \ return ave_time; \ @@ -318,7 +376,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til { \ float ave_time = \ ck_tile::launch_kernel(s, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \ + MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, true), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \ return ave_time; \ @@ -327,7 +385,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til { \ float ave_time = ck_tile::launch_kernel( \ s, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \ + MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, false), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \ return ave_time; \ @@ -344,67 +402,156 @@ float fused_moesorting_mp(fused_moesorting_trait t, using ms_index_t = ck_tile::index_t; using ms_weight_type = float; - if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > - ck_tile::get_smem_capacity()) + auto maybe_clear_workspace = [=](const ck_tile::stream_config& s_) { + if(t.clear_workspace_inside_api) + { + if(is_local_token) + { + auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(true, 1024, 1); + k(s_); + } + else + { + auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(false, 1024, 1); + k(s_); + } + } + }; + + if(a.tokens < 2048) { + if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > + ck_tile::get_smem_capacity()) + { #if MOE_SORTING_SUPPORT_LARGE_EXPERT - if(t.local_expert_masking) - { - float ave_time = ck_tile::launch_kernel(s, - MOE_SORTING_MP_0(ms_index_t, 1, true), - MOE_SORTING_MP_1(ms_index_t, 1, true), - MOE_SORTING_MP_2(ms_index_t, 1, true), - MOE_SORTING_MP_3(ms_index_t, 1, true)); - return ave_time; - } - else - { - float ave_time = ck_tile::launch_kernel(s, - MOE_SORTING_MP_0(ms_index_t, 1, false), - MOE_SORTING_MP_1(ms_index_t, 1, false), - MOE_SORTING_MP_2(ms_index_t, 1, false), - MOE_SORTING_MP_3(ms_index_t, 1, false)); - return ave_time; - } -#else - printf("do not support large expert %d\n", a.num_experts); - return -1; -#endif - } - else - { - ck_tile::index_t mesh_byte_size = - ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk); - if(mesh_byte_size == 1) - { - if(a.tokens * a.topk % 4 == 0) + if(t.local_expert_masking) { - MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16) + float ave_time = + ck_tile::launch_kernel(s, + MOE_SORTING_MP_0_V2(ms_index_t, 1, true), + MOE_SORTING_MP_2(ms_index_t, 1, true), + MOE_SORTING_MP_3(ms_index_t, 1, true)); + return ave_time; } else { - MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16) - } - } - else if(mesh_byte_size == 2) - { -#if MOE_SORTING_SUPPORT_LARGE_TOPK - if(a.tokens * a.topk % 4 == 0) - { - MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8) - } - else - { - MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8) + float ave_time = + ck_tile::launch_kernel(s, + MOE_SORTING_MP_0_V2(ms_index_t, 1, false), + MOE_SORTING_MP_2(ms_index_t, 1, false), + MOE_SORTING_MP_3(ms_index_t, 1, false)); + return ave_time; } #else - printf("do not support large topk %d\n", a.topk); + printf("do not support large expert %d\n", a.num_experts); return -1; #endif } else { - MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1) + ck_tile::index_t mesh_byte_size = + ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk); + if(mesh_byte_size == 1) + { + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 4, 16, 16) + } + else + { + MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 1, 16, 16) + } + } + else if(mesh_byte_size == 2) + { +#if MOE_SORTING_SUPPORT_LARGE_TOPK + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 4, 8, 8) + } + else + { + MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 1, 8, 8) + } +#else + printf("do not support large topk %d\n", a.topk); + return -1; +#endif + } + else + { + MOR_SORTING_MP_DISPATCH_SMALL_(ck_tile::index_t, 1, 1, 1) + } + } + } + else + { + if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > + ck_tile::get_smem_capacity()) + { +#if MOE_SORTING_SUPPORT_LARGE_EXPERT + if(t.local_expert_masking) + { + float ave_time = + ck_tile::launch_kernel(s, + maybe_clear_workspace, + MOE_SORTING_MP_0_V1(ms_index_t, 1, true), + MOE_SORTING_MP_1(ms_index_t, 1, true), + MOE_SORTING_MP_2(ms_index_t, 1, true), + MOE_SORTING_MP_3(ms_index_t, 1, true)); + return ave_time; + } + else + { + float ave_time = + ck_tile::launch_kernel(s, + maybe_clear_workspace, + MOE_SORTING_MP_0_V1(ms_index_t, 1, false), + MOE_SORTING_MP_1(ms_index_t, 1, false), + MOE_SORTING_MP_2(ms_index_t, 1, false), + MOE_SORTING_MP_3(ms_index_t, 1, false)); + return ave_time; + } +#else + printf("do not support large expert %d\n", a.num_experts); + return -1; +#endif + } + else + { + ck_tile::index_t mesh_byte_size = + ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk); + if(mesh_byte_size == 1) + { + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16) + } + else + { + MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16) + } + } + else if(mesh_byte_size == 2) + { +#if MOE_SORTING_SUPPORT_LARGE_TOPK + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8) + } + else + { + MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8) + } +#else + printf("do not support large topk %d\n", a.topk); + return -1; +#endif + } + else + { + MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1) + } } } } diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 09c2510d3e..2918cd33bc 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -20,7 +20,7 @@ namespace ck_tile { #endif #ifndef MOE_SORTING_FUSE_MP_01 -#define MOE_SORTING_FUSE_MP_01 0 +#define MOE_SORTING_FUSE_MP_01 1 #endif // weather use 2d buffer indexing for fmoe ws or 1d @@ -527,7 +527,7 @@ struct MoeSortingKernel } __syncthreads(); -#if 1 +#if MOE_SORTING_FUSE_MP_01 if(tid < num_experts) { tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0; @@ -1322,18 +1322,18 @@ CK_TILE_DEVICE void moe_sorting_wave_cumsum(data_t& thread_data) } } -template +template CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, long_index_t buf_bytes, index_t gid) { - // const index_t offset = (blockIdx.x - 1) * BLOCK_SIZE + threadIdx.x; - long_index_t offset = static_cast(gid) * BLOCK_SIZE + threadIdx.x; + // const index_t offset = (blockIdx.x - 1) * kBlockSize + threadIdx.x; + long_index_t offset = static_cast(gid) * kBlockSize + threadIdx.x; if(offset < buf_bytes / 16) { buf[offset] = uint8x16_t{0}; } } -template +template CK_TILE_DEVICE void moe_buf_set_zero_kernel_2d( void* buf, index_t row, index_t col, index_t elem_bytes, index_t gid, index_t blocks) { @@ -1345,7 +1345,7 @@ CK_TILE_DEVICE void moe_buf_set_zero_kernel_2d( vector_type* p_buf = reinterpret_cast(buf); auto zero_ = vector_type{0}; - for(long_index_t i = gid * BLOCK_SIZE + threadIdx.x; i < total_elems; i += blocks * BLOCK_SIZE) + for(long_index_t i = gid * kBlockSize + threadIdx.x; i < total_elems; i += blocks * kBlockSize) { p_buf[i] = zero_; } @@ -1552,7 +1552,7 @@ p_m_cumsum // count topk_id into mesh template -struct MoeSortingMultiPhaseKernel_P0 +struct MoeSortingMultiPhaseKernel_P0_v1 { using Problem = remove_cvref_t; @@ -1673,6 +1673,197 @@ struct MoeSortingMultiPhaseKernel_P0 } } }; +template +struct MoeSortingMultiPhaseKernel_P0_v2 +{ + using Problem = remove_cvref_t; + + using IndexType = typename Problem::IndexType; + using WeightType = typename Problem::WeightType; + using MeshType = typename Problem::MeshType; + + static constexpr index_t kBlockSize = 512; + + typedef MoeSortingHostArgs MoeSortingKargs; + + using Hargs = MoeSortingHostArgs; + + struct Kargs + { + const void* p_topk_ids; // [tokens, topk] + const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens + void* p_expert_mesh; // [expert, tokens] + index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens + // used for ws/LDS calculation + index_t mesh_stride; // mesh_stride for p_expert_mesh + mdiv topk_mdiv; + + const void* p_local_expert_mask; // [expert] + void* p_expert_cumsum; // [expert] + index_t num_experts; + }; + + CK_TILE_HOST static constexpr auto get_num_cu() + { + index_t num_cu = [&]() { + hipDeviceProp_t dev_prop; + hipDevice_t dev; + HIP_CHECK_ERROR(hipGetDevice(&dev)); + HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev)); + return dev_prop.multiProcessorCount; + }(); + return num_cu; + } + + CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) + { + Kargs k; + k.p_topk_ids = h.p_topk_ids; + k.p_local_tokens = h.p_local_tokens; + k.p_expert_mesh = h.p_ws; + k.p_expert_cumsum = reinterpret_cast( + reinterpret_cast(h.p_ws) + + impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk)); + k.tokens = h.tokens; + k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens); + k.topk_mdiv = mdiv{static_cast(h.topk)}; + k.p_local_expert_mask = h.p_local_expert_mask; + k.num_experts = h.num_experts; + return k; + } + + CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) { return h.num_experts; } + + CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); } + + // in byte + // CK_TILE_HOST static constexpr auto GetSmemSize() { return 0; } + CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() + { + return kBlockSize / get_warp_size() * sizeof(IndexType); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + constexpr index_t index_pack = Problem::SubTokenTile; // always packed + __shared__ char smem[GetSmemSize()]; + using topk_id_t = ext_vector_t; + const int eid = blockIdx.x; + const topk_id_t* p_topk_ids = reinterpret_cast(kargs.p_topk_ids); + const IndexType* p_local_expert_mask = + static_cast(kargs.p_local_expert_mask); + IndexType* p_expert_cumsum = reinterpret_cast(kargs.p_expert_cumsum); + index_t lane_id = threadIdx.x % get_warp_size(); + index_t wave_id = threadIdx.x / get_warp_size(); + const index_t tokens = [&]() { + if constexpr(Problem::LocalToken) + { + return reinterpret_cast(kargs.p_local_tokens)[0]; + } + else + { + return kargs.tokens; + } + }(); + index_t rounded_tokens = [&]() { + if constexpr(Problem::LocalToken) + { + return (tokens + index_pack - 1) / index_pack * index_pack; + } + else + return tokens; + }(); + index_t mesh_stride = [&]() { + if constexpr(Problem::LocalToken) + { + return impl::moe_sorting_mp_mesh_stride(tokens); + } + else + { + return kargs.mesh_stride; + } + }(); + + IndexType mask = 1; + if constexpr(Problem::LocalExpertMasking) + { + mask = p_local_expert_mask[eid]; + } + MeshType* p_expert_mesh = + reinterpret_cast(kargs.p_expert_mesh) + eid * mesh_stride; + for(index_t i = threadIdx.x; i < mesh_stride; i += kBlockSize) + { + p_expert_mesh[i] = 0; + } + ck_tile::block_sync_load_raw(0); + + index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / index_pack; + +#pragma unroll index_pack + for(index_t i = threadIdx.x; i < total_elem; i += kBlockSize) + { + auto x = p_topk_ids[i]; + static_for<0, index_pack, 1>{}([&](auto j) { + IndexType eid_x = x[j.value]; // ext_vector_type must use int to [] + if(eid_x == eid) + { + uint32_t curr_token_id, curr_topk_id; + kargs.topk_mdiv.divmod(i * index_pack + j, curr_token_id, curr_topk_id); + if constexpr(Problem::LocalToken) + { + if(static_cast(curr_token_id) < tokens) + p_expert_mesh[curr_token_id] = (curr_topk_id + 1) & 0xffff; + } + else + p_expert_mesh[curr_token_id] = (curr_topk_id + 1) & 0xffff; + } + }); + } + ck_tile::block_sync_load_raw(0); + + { + + using r_t = ext_vector_t; // always use int32x4 + auto f_sum = [](auto x_, auto y_) { return x_ + y_; }; + const r_t* p_expert_mesh_r = reinterpret_cast(p_expert_mesh); + + int loops = (mesh_stride / index_pack + kBlockSize - 1) / kBlockSize; + + if(Problem::LocalToken && mask == 0) + return; // skip + index_t cnt = 0; // per-wave cnt + for(int i = 0; i < loops; i++) + { + int position = i * kBlockSize + threadIdx.x; + r_t v{0}; + if(position < (mesh_stride / index_pack)) + v = p_expert_mesh_r[position]; + index_t local_sum = 0; + static_for<0, index_pack, 1>{}( + [&](auto i_vec) { local_sum += v[i_vec.value] != 0 ? 1 : 0; }); + cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum); + } + + // reduce cross wave + IndexType* s = reinterpret_cast(smem); + if(lane_id == 0) + { + s[wave_id] = cnt; + } + __syncthreads(); + + if(threadIdx.x == 0) + { + index_t c = 0; + for(auto i = 0; i < (kBlockSize / get_warp_size()); i++) + { + c += s[i]; + } + p_expert_cumsum[eid] = c; + } + } + } +}; // cnt total tokens for a expert template diff --git a/test/ck_tile/moe_sorting/moe_sorting_api.cpp b/test/ck_tile/moe_sorting/moe_sorting_api.cpp index 0cf600d2b4..11ccdef69e 100644 --- a/test/ck_tile/moe_sorting/moe_sorting_api.cpp +++ b/test/ck_tile/moe_sorting/moe_sorting_api.cpp @@ -194,22 +194,40 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi return -1; } -#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_) \ - [&]() { \ - constexpr ck_tile::index_t unroll_num = unroll_num_; \ - constexpr bool expert_masking = expert_masking_; \ - constexpr bool local_token = local_token_; \ - using ms_problem = ck_tile::MoeSortingProblemMp; \ - using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0; \ - auto kargs = kernel::MakeKargs(a); \ - const dim3 grids = kernel::GridSize(a); \ - const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ +#define MOE_SORTING_MP_0_V1(mesh_type_, unroll_num_, expert_masking_, local_token_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + constexpr bool local_token = local_token_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v1; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() + +#define MOE_SORTING_MP_0_V2(mesh_type_, unroll_num_, expert_masking_, local_token_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + constexpr bool local_token = local_token_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v2; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() #define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \ @@ -286,6 +304,46 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \ }() +#define MOR_SORTING_MP_DISPATCH_SMALL_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ + if(t.local_expert_masking) \ + { \ + if(is_local_token) \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, true), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \ + return ave_time; \ + } \ + else \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, false), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \ + return ave_time; \ + } \ + } \ + else \ + { \ + if(is_local_token) \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, true), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \ + return ave_time; \ + } \ + else \ + { \ + float ave_time = ck_tile::launch_kernel( \ + s, \ + MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, false), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \ + return ave_time; \ + } \ + } + #define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ if(t.local_expert_masking) \ { \ @@ -294,7 +352,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi float ave_time = \ ck_tile::launch_kernel(s, \ maybe_clear_workspace, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \ + MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, true), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \ return ave_time; \ @@ -304,7 +362,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi float ave_time = \ ck_tile::launch_kernel(s, \ maybe_clear_workspace, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \ + MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, false), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \ return ave_time; \ @@ -317,7 +375,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi float ave_time = \ ck_tile::launch_kernel(s, \ maybe_clear_workspace, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \ + MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, true), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \ return ave_time; \ @@ -327,7 +385,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi float ave_time = ck_tile::launch_kernel( \ s, \ maybe_clear_workspace, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \ + MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, false), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \ return ave_time; \ @@ -368,70 +426,140 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co } } }; - - if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > - ck_tile::get_smem_capacity()) + if(!ck_tile::is_gfx12_supported() && a.tokens < 2048) { + if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > + ck_tile::get_smem_capacity()) + { #if MOE_SORTING_SUPPORT_LARGE_EXPERT - if(t.local_expert_masking) - { - float ave_time = ck_tile::launch_kernel(s, - maybe_clear_workspace, - MOE_SORTING_MP_0(ms_index_t, 1, true), - MOE_SORTING_MP_1(ms_index_t, 1, true), - MOE_SORTING_MP_2(ms_index_t, 1, true), - MOE_SORTING_MP_3(ms_index_t, 1, true)); - return ave_time; - } - else - { - float ave_time = ck_tile::launch_kernel(s, - maybe_clear_workspace, - MOE_SORTING_MP_0(ms_index_t, 1, false), - MOE_SORTING_MP_1(ms_index_t, 1, false), - MOE_SORTING_MP_2(ms_index_t, 1, false), - MOE_SORTING_MP_3(ms_index_t, 1, false)); - return ave_time; - } -#else - printf("do not support large expert %d\n", a.num_experts); - return -1; -#endif - } - else - { - ck_tile::index_t mesh_byte_size = - ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk); - if(mesh_byte_size == 1) - { - if(a.tokens * a.topk % 4 == 0) + if(t.local_expert_masking) { - MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16) + float ave_time = + ck_tile::launch_kernel(s, + MOE_SORTING_MP_0_V2(ms_index_t, 1, true), + MOE_SORTING_MP_2(ms_index_t, 1, true), + MOE_SORTING_MP_3(ms_index_t, 1, true)); + return ave_time; } else { - MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16) - } - } - else if(mesh_byte_size == 2) - { -#if MOE_SORTING_SUPPORT_LARGE_TOPK - if(a.tokens * a.topk % 4 == 0) - { - MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8) - } - else - { - MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8) + float ave_time = + ck_tile::launch_kernel(s, + MOE_SORTING_MP_0_V2(ms_index_t, 1, false), + MOE_SORTING_MP_2(ms_index_t, 1, false), + MOE_SORTING_MP_3(ms_index_t, 1, false)); + return ave_time; } #else - printf("do not support large topk %d\n", a.topk); + printf("do not support large expert %d\n", a.num_experts); return -1; #endif } else { - MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1) + ck_tile::index_t mesh_byte_size = + ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk); + if(mesh_byte_size == 1) + { + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 4, 16, 16) + } + else + { + MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 1, 16, 16) + } + } + else if(mesh_byte_size == 2) + { +#if MOE_SORTING_SUPPORT_LARGE_TOPK + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 4, 8, 8) + } + else + { + MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 1, 8, 8) + } +#else + printf("do not support large topk %d\n", a.topk); + return -1; +#endif + } + else + { + MOR_SORTING_MP_DISPATCH_SMALL_(ck_tile::index_t, 1, 1, 1) + } + } + } + else + { + if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > + ck_tile::get_smem_capacity()) + { +#if MOE_SORTING_SUPPORT_LARGE_EXPERT + if(t.local_expert_masking) + { + float ave_time = + ck_tile::launch_kernel(s, + maybe_clear_workspace, + MOE_SORTING_MP_0_V1(ms_index_t, 1, true), + MOE_SORTING_MP_1(ms_index_t, 1, true), + MOE_SORTING_MP_2(ms_index_t, 1, true), + MOE_SORTING_MP_3(ms_index_t, 1, true)); + return ave_time; + } + else + { + float ave_time = + ck_tile::launch_kernel(s, + maybe_clear_workspace, + MOE_SORTING_MP_0_V1(ms_index_t, 1, false), + MOE_SORTING_MP_1(ms_index_t, 1, false), + MOE_SORTING_MP_2(ms_index_t, 1, false), + MOE_SORTING_MP_3(ms_index_t, 1, false)); + return ave_time; + } +#else + printf("do not support large expert %d\n", a.num_experts); + return -1; +#endif + } + else + { + ck_tile::index_t mesh_byte_size = + ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk); + if(mesh_byte_size == 1) + { + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16) + } + else + { + MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16) + } + } + else if(mesh_byte_size == 2) + { +#if MOE_SORTING_SUPPORT_LARGE_TOPK + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8) + } + else + { + MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8) + } +#else + printf("do not support large topk %d\n", a.topk); + return -1; +#endif + } + else + { + MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1) + } } } }