diff --git a/example/ck_tile/13_moe_sorting/moe_sorting.cpp b/example/ck_tile/13_moe_sorting/moe_sorting.cpp index e59fcaedad..ce689a370c 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting.cpp @@ -153,9 +153,8 @@ bool test_moe_sorting(ck_tile::ArgParser args) local_expert_masking_dev.ToDevice(local_expert_masking_host.data()); // if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr - ck_tile::index_t workspace_size = moe_sorting_get_workspace_size(tokens, num_experts); + ck_tile::index_t workspace_size = moe_sorting_get_workspace_size(tokens, num_experts, topk); ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0); - if(workspace_size != 0) moe_sorting_ws.SetZero(); // note, clear here!!!! diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp index 109ec1b157..305cf118d2 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp @@ -7,6 +7,14 @@ #define MOE_SORTING_USE_EX_KERNEL 1 #endif +#ifndef MOE_SORTING_SUPPORT_LARGE_EXPERT +#define MOE_SORTING_SUPPORT_LARGE_EXPERT 0 +#endif + +#ifndef MOE_SORTING_SUPPORT_LARGE_TOPK +#define MOE_SORTING_SUPPORT_LARGE_TOPK 0 +#endif + #if !MOE_SORTING_USE_EX_KERNEL #define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \ @@ -153,7 +161,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi } } #else - if(moe_sorting_get_workspace_size(a.tokens, a.num_experts) != 0) + if(moe_sorting_get_workspace_size(a.tokens, a.num_experts, a.topk) != 0) { return moe_sorting_mp(t, a, s); } @@ -171,57 +179,107 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi return -1; } -#define MOE_SORTING_MP_0(unroll_num_, expert_masking_) \ - [&]() { \ - constexpr ck_tile::index_t unroll_num = unroll_num_; \ - constexpr bool expert_masking = expert_masking_; \ - using ms_problem = \ - ck_tile::MoeSortingProblemMp; \ - using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0; \ - auto kargs = kernel::MakeKargs(a); \ - const dim3 grids = kernel::GridSize(a); \ - const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ +#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() -#define MOE_SORTING_MP_1(unroll_num_, expert_masking_) \ - [&]() { \ - constexpr ck_tile::index_t unroll_num = unroll_num_; \ - constexpr bool expert_masking = expert_masking_; \ - using ms_problem = \ - ck_tile::MoeSortingProblemMp; \ - using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1; \ - auto kargs = kernel::MakeKargs(a); \ - const dim3 grids = kernel::GridSize(a); \ - const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ +#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() +#if MOE_SORTING_SUPPORT_LARGE_EXPERT +#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() -#define MOE_SORTING_MP_2(unroll_num_, expert_masking_) \ - [&]() { \ - constexpr ck_tile::index_t unroll_num = unroll_num_; \ - constexpr bool expert_masking = expert_masking_; \ - using ms_problem = \ - ck_tile::MoeSortingProblemMp; \ - using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2; \ - auto kargs = kernel::MakeKargs(a); \ - const dim3 grids = kernel::GridSize(a); \ - const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ +#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() +#endif + +#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + const auto lds_size = kernel::GetSmemSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \ }() -#define MOE_SORTING_MP_3(unroll_num_, expert_masking_) \ - [&]() { \ - constexpr ck_tile::index_t unroll_num = unroll_num_; \ - constexpr bool expert_masking = expert_masking_; \ - using ms_problem = \ - ck_tile::MoeSortingProblemMp; \ - using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3; \ - auto kargs = kernel::MakeKargs(a); \ - const dim3 grids = kernel::GridSize(a); \ - const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ - }() +#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ + if(t.local_expert_masking) \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true)); \ + return ave_time; \ + } \ + else \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false)); \ + return ave_time; \ + } float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) { @@ -230,29 +288,74 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co using ms_index_t = ck_tile::index_t; using ms_weight_type = float; - if(t.local_expert_masking) + if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > + ck_tile::get_smem_capacity()) { - float ave_time = ck_tile::launch_kernel(s, - MOE_SORTING_MP_0(1, true), - MOE_SORTING_MP_1(1, true), - MOE_SORTING_MP_2(1, true), - MOE_SORTING_MP_3(1, true)); - return ave_time; +#if MOE_SORTING_SUPPORT_LARGE_EXPERT + if(t.local_expert_masking) + { + float ave_time = ck_tile::launch_kernel(s, + MOE_SORTING_MP_0(ms_index_t, 1, true), + MOE_SORTING_MP_1(ms_index_t, 1, true), + MOE_SORTING_MP_2(ms_index_t, 1, true), + MOE_SORTING_MP_3(ms_index_t, 1, true)); + return ave_time; + } + else + { + float ave_time = ck_tile::launch_kernel(s, + MOE_SORTING_MP_0(ms_index_t, 1, false), + MOE_SORTING_MP_1(ms_index_t, 1, false), + MOE_SORTING_MP_2(ms_index_t, 1, false), + MOE_SORTING_MP_3(ms_index_t, 1, false)); + return ave_time; + } +#else + printf("do not support large expert %d\n", a.num_experts); + return -1; +#endif } else { - float ave_time = ck_tile::launch_kernel(s, - MOE_SORTING_MP_0(1, false), - MOE_SORTING_MP_1(1, false), - MOE_SORTING_MP_2(1, false), - MOE_SORTING_MP_3(1, false)); - return ave_time; + ck_tile::index_t mesh_byte_size = + ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk); + if(mesh_byte_size == 1) + { + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16) + } + else + { + MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16) + } + } + else if(mesh_byte_size == 2) + { +#if MOE_SORTING_SUPPORT_LARGE_TOPK + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8) + } + else + { + MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8) + } +#else + printf("do not support large topk %d\n", a.topk); + return -1; +#endif + } + else + { + MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1) + } } } return -1; } -int moe_sorting_get_workspace_size(int tokens, int num_experts) +int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk) { - return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts); + return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk); } diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp index b47ae9013b..0fe8d81e70 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp @@ -22,6 +22,6 @@ struct moe_sorting_args : public ck_tile::MoeSortingHostArgs // if return non zero, means need workspace, you need to allocate a GPU buffer // and set to moe_sorting_args.p_ws // NOTE: workspace size are required to clear zero before use the API -int moe_sorting_get_workspace_size(int tokens, int num_experts); +int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk); float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s); float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s); diff --git a/example/ck_tile/13_moe_sorting/script/smoke_test.sh b/example/ck_tile/13_moe_sorting/script/smoke_test.sh index cf2c2e164b..fbfb10822c 100644 --- a/example/ck_tile/13_moe_sorting/script/smoke_test.sh +++ b/example/ck_tile/13_moe_sorting/script/smoke_test.sh @@ -26,3 +26,9 @@ $EXE -t=13 -e=64 -k=3 -local_eid=4,5,6,7,8,9,10,11 $EXE -t=99 -e=33 -k=9 -local_eid=6,10,11,15,19 $EXE -t=80 -e=99 -k=10 -local_eid=0,8,12,33 $EXE -t=11 -e=256 -k=5 -local_eid=99,110,129 +$EXE -t=128 -e=128 -k=6 -moe_buf_size=163840 +$EXE -t=8192 -e=32 -k=5 -moe_buf_size=163840 +$EXE -t=8192 -e=32 -k=8 -moe_buf_size=163840 +$EXE -t=8192 -e=256 -k=5 -moe_buf_size=163840 +$EXE -t=8192 -e=256 -k=8 -moe_buf_size=163840 +$EXE -t=163840 -e=256 -k=8 -moe_buf_size=163840 \ No newline at end of file diff --git a/example/ck_tile/15_fused_moe/fused_moe.hpp b/example/ck_tile/15_fused_moe/fused_moe.hpp index b354d1d347..46425384cc 100644 --- a/example/ck_tile/15_fused_moe/fused_moe.hpp +++ b/example/ck_tile/15_fused_moe/fused_moe.hpp @@ -56,4 +56,6 @@ struct fused_moe_traits bool local_expert_masking; // if mask experts as local expert }; +// if return zero, no ws needed +int fused_moe_get_workspace_size(int tokens, int num_experts, int topk); float fused_moe(fused_moe_traits, fused_moe_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/15_fused_moe/fused_moesorting.hpp b/example/ck_tile/15_fused_moe/fused_moesorting.hpp index a3ff8c5bf7..11e1c6e531 100644 --- a/example/ck_tile/15_fused_moe/fused_moesorting.hpp +++ b/example/ck_tile/15_fused_moe/fused_moesorting.hpp @@ -18,4 +18,5 @@ struct fused_moesorting_args : public ck_tile::MoeSortingHostArgs { }; +int fused_moe_get_workspace_size(int tokens, int num_experts, int topk); float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s); diff --git a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp index f887d57aa9..b3515b1bec 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp @@ -2,6 +2,12 @@ // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "fused_moe.hpp" +#include "ck_tile/ops/fused_moe.hpp" + +int fused_moe_get_workspace_size(int tokens, int num_experts, int topk) +{ + return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk); +} float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_config& s) { diff --git a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp index 7aedaa9317..0d83c48d02 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp @@ -7,6 +7,14 @@ #define MOE_SORTING_USE_EX_KERNEL 1 #endif +#ifndef MOE_SORTING_SUPPORT_LARGE_EXPERT +#define MOE_SORTING_SUPPORT_LARGE_EXPERT 0 +#endif + +#ifndef MOE_SORTING_SUPPORT_LARGE_TOPK +#define MOE_SORTING_SUPPORT_LARGE_TOPK 0 +#endif + #if !MOE_SORTING_USE_EX_KERNEL #define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \ @@ -107,6 +115,10 @@ } #endif +float fused_moesorting_mp(fused_moesorting_trait t, + fused_moesorting_args a, + ck_tile::stream_config s); + float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s) { if(t.weight_type == "fp32" && t.index_type == "int32") @@ -153,18 +165,198 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til } } #else - using index_t = ck_tile::index_t; - using ms_weight_type = float; - auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts); - auto sub_token_ = r_ - 2; - r_ = (r_ - 2) / 8; - bool is_sub_token_onshot = a.tokens <= sub_token_; + if(fused_moe_get_workspace_size(a.tokens, a.num_experts, a.topk) != 0) + { + return fused_moesorting_mp(t, a, s); + } + using index_t = ck_tile::index_t; + using ms_weight_type = float; + auto sub_token_ = ck_tile::moe_sorting_get_sub_token(a.tokens, a.num_experts); + auto row_ = sub_token_ / 8; + bool is_sub_token_onshot = a.tokens <= sub_token_; bool is_local_expert_masking = t.local_expert_masking; - (void)c_; - MOE_SORTING_DISPATCH_EMASK_(r_); + MOE_SORTING_DISPATCH_EMASK_(row_); // MOE_SORTING_DISPATCH_ETILE(0, 0); #endif } return -1; } + +#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() + +#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() +#if MOE_SORTING_SUPPORT_LARGE_EXPERT +#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() + +#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() +#endif + +#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + const auto lds_size = kernel::GetSmemSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \ + }() + +#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ + if(t.local_expert_masking) \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true)); \ + return ave_time; \ + } \ + else \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false)); \ + return ave_time; \ + } + +float fused_moesorting_mp(fused_moesorting_trait t, + fused_moesorting_args a, + ck_tile::stream_config s) +{ + if(t.weight_type == "fp32" && t.index_type == "int32") + { + using ms_index_t = ck_tile::index_t; + using ms_weight_type = float; + + if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > + ck_tile::get_smem_capacity()) + { +#if MOE_SORTING_SUPPORT_LARGE_EXPERT + if(t.local_expert_masking) + { + float ave_time = ck_tile::launch_kernel(s, + MOE_SORTING_MP_0(ms_index_t, 1, true), + MOE_SORTING_MP_1(ms_index_t, 1, true), + MOE_SORTING_MP_2(ms_index_t, 1, true), + MOE_SORTING_MP_3(ms_index_t, 1, true)); + return ave_time; + } + else + { + float ave_time = ck_tile::launch_kernel(s, + MOE_SORTING_MP_0(ms_index_t, 1, false), + MOE_SORTING_MP_1(ms_index_t, 1, false), + MOE_SORTING_MP_2(ms_index_t, 1, false), + MOE_SORTING_MP_3(ms_index_t, 1, false)); + return ave_time; + } +#else + printf("do not support large expert %d\n", a.num_experts); + return -1; +#endif + } + else + { + ck_tile::index_t mesh_byte_size = + ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk); + if(mesh_byte_size == 1) + { + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16) + } + else + { + MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16) + } + } + else if(mesh_byte_size == 2) + { +#if MOE_SORTING_SUPPORT_LARGE_TOPK + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8) + } + else + { + MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8) + } +#else + printf("do not support large topk %d\n", a.topk); + return -1; +#endif + } + else + { + MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1) + } + } + } + return -1; +} diff --git a/example/ck_tile/15_fused_moe/main.cpp b/example/ck_tile/15_fused_moe/main.cpp index cb93ce8907..da843891ce 100644 --- a/example/ck_tile/15_fused_moe/main.cpp +++ b/example/ck_tile/15_fused_moe/main.cpp @@ -372,7 +372,8 @@ bool run(const ck_tile::ArgParser& arg_parser) num_sorted_tiles_host.get_element_space_size_in_bytes()); // if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr - ck_tile::index_t workspace_size = ck_tile::moe_sorting_get_workspace_size(tokens, experts); + ck_tile::index_t workspace_size = + ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk); ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0); if(workspace_size != 0) moe_sorting_ws.SetZero(); // note, clear here!!!! diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 821b3a8e84..b94157eaec 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -13,6 +13,7 @@ #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/generic_memory_space_atomic.hpp" #include "ck_tile/core/arch/utility.hpp" +#include "ck_tile/core/arch/workgroup_barrier.hpp" #include "ck_tile/core/config.hpp" #include "ck_tile/core/container/array.hpp" #include "ck_tile/core/container/container_helper.hpp" diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 09de5f325f..1d3cf5c010 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -154,4 +154,13 @@ __host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_addres #pragma clang diagnostic pop } +CK_TILE_HOST_DEVICE constexpr index_t get_smem_capacity() +{ +#if defined(__gfx950__) + return 163840; +#else + return 65536; +#endif +} + } // namespace ck_tile diff --git a/include/ck_tile/core/arch/workgroup_barrier.hpp b/include/ck_tile/core/arch/workgroup_barrier.hpp new file mode 100644 index 0000000000..827a490fcb --- /dev/null +++ b/include/ck_tile/core/arch/workgroup_barrier.hpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" + +namespace ck_tile { + +struct workgroup_barrier +{ + CK_TILE_DEVICE workgroup_barrier(uint32_t* ptr) : base_ptr(ptr) {} + + CK_TILE_DEVICE uint32_t ld(uint32_t offset = 0) + { + return __atomic_load_n(base_ptr + offset, __ATOMIC_RELAXED); + } + + CK_TILE_DEVICE void wait_eq(uint32_t value, uint32_t offset = 0) + { + if(threadIdx.x == 0) + { + while(ld(offset) != value) {} + } + __syncthreads(); + } + + CK_TILE_DEVICE void wait_lt(uint32_t value, uint32_t offset = 0) + { + if(threadIdx.x == 0) + { + while(ld(offset) < value) {} + } + __syncthreads(); + } + + CK_TILE_DEVICE void wait_set(uint32_t compare, uint32_t value, uint32_t offset = 0) + { + if(threadIdx.x == 0) + { + while(atomicCAS(base_ptr + offset, compare, value) != compare) {} + } + __syncthreads(); + } + + // enter critical zoon, assume buffer is zero when launch kernel + CK_TILE_DEVICE void aquire(uint32_t offset = 0) { wait_set(offset, 0, 1); } + + // exit critical zoon, assume buffer is zero when launch kernel + CK_TILE_DEVICE void release(uint32_t offset = 0) { wait_set(offset, 1, 0); } + + CK_TILE_DEVICE void inc(uint32_t offset = 0) + { + __syncthreads(); + if(threadIdx.x == 0) + { + atomicAdd(base_ptr + offset, 1); + } + } + + uint32_t* base_ptr; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 414509e479..27133fa847 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -257,5 +257,5 @@ #endif #ifndef CK_TILE_WA_ISSUE_2028 -#define CK_TILE_WA_ISSUE_2028 1 +#define CK_TILE_WA_ISSUE_2028 0 #endif diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 6a7ccd2472..664294fe18 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -19,6 +19,10 @@ namespace ck_tile { #define MOE_SORTING_USE_EX_KERNEL 1 #endif +#ifndef MOE_SORTING_FUSE_MP_01 +#define MOE_SORTING_FUSE_MP_01 0 +#endif + // clang-format off // [indexing implementation-1] // using M_a as constexpr block_size to partition all tokens into different slices @@ -118,7 +122,7 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int tokens_, int num_ex int smem_cols = num_experts_ + 1; // usually experts is power of 2. padding here int smem_rows = [&](){ index_t target_occupancy_ = 2; - constexpr index_t total_ = 65536 / sizeof(int); + constexpr index_t total_ = get_smem_capacity() / sizeof(index_t); constexpr index_t sub_unroll = 8; constexpr index_t cumsum_bufs = 2; // 1 for cumsum, 1 for cnt // at lease 2 lines, one for sub_token unroll, one for cumsum @@ -250,7 +254,7 @@ struct MoeSortingKernel { #if MOE_SORTING_USE_EX_KERNEL auto [smem_rows, smem_cols] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts); - return smem_rows * smem_cols * sizeof(int); + return smem_rows * smem_cols * sizeof(index_t); #else const auto blocks = BlockSize(h); // usually num_experts is power of 2, we pad 1 dword here for the row-size @@ -1063,17 +1067,43 @@ CK_TILE_HOST_DEVICE index_t moe_sorting_mp_mesh_stride(index_t tokens) return (tokens + chunk - 1) / chunk * chunk; }; -CK_TILE_HOST_DEVICE index_t moe_sorting_mp_mesh_elem(index_t tokens, index_t num_experts) +// 4-i32 mesh, 2-i16 mseh, 1-i8 mesh +CK_TILE_HOST index_t moe_sorting_mesh_byte_size(index_t tokens_, + index_t /*num_experts_*/, + index_t topk_) +{ + // small token case, let's run mesh with dword score board + if(tokens_ < 512) + return 4; + else + { + if(topk_ >= 255) + return 2; // 16bit mesh + else + return 1; // 8bit mesh if small enough + } +} + +CK_TILE_HOST_DEVICE index_t moe_sorting_mp_mesh_smem_size(index_t tokens, + index_t num_experts, + index_t topk) { index_t row_size = moe_sorting_mp_mesh_stride(tokens); - return num_experts * row_size; + index_t elem = num_experts * row_size; + return elem * moe_sorting_mesh_byte_size(tokens, num_experts, topk); }; -CK_TILE_HOST_DEVICE index_t moe_sorting_mp_cumsum_elem(index_t num_experts) +CK_TILE_HOST_DEVICE index_t moe_sorting_mp_cumsum_smem_size(index_t num_experts) { constexpr index_t chunk = 32; index_t row_size = num_experts + 1; - return (row_size + chunk - 1) / chunk * chunk; + return (row_size + chunk - 1) / chunk * chunk * sizeof(index_t); +}; + +CK_TILE_HOST_DEVICE index_t moe_sorting_mp_sem_smem_size() +{ + constexpr index_t chunk = 32; + return chunk * sizeof(index_t); }; template @@ -1245,15 +1275,20 @@ CK_TILE_HOST bool moe_sorting_is_oneshot(int tokens_, int num_experts_) } // return size in byte -CK_TILE_HOST index_t moe_sorting_mp_get_workspace_size(int tokens_, int num_experts_) +CK_TILE_HOST index_t moe_sorting_mp_get_workspace_size(int tokens_, int num_experts_, int topk_) { - index_t elem = impl::moe_sorting_mp_mesh_elem(tokens_, num_experts_) + - impl::moe_sorting_mp_cumsum_elem(num_experts_); - return elem * sizeof(index_t); + index_t s_ = impl::moe_sorting_mp_mesh_smem_size(tokens_, num_experts_, topk_) + + impl::moe_sorting_mp_cumsum_smem_size(num_experts_) +#if MOE_SORTING_FUSE_MP_01 + + impl::moe_sorting_mp_sem_smem_size(); +#else + ; +#endif + return s_; } // return size in byte -CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, int num_experts_) +CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, int num_experts_, int topk_) { #if 1 if(moe_sorting_is_oneshot(tokens_, num_experts_)) @@ -1262,10 +1297,10 @@ CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, int num_experts } else { - return moe_sorting_mp_get_workspace_size(tokens_, num_experts_); + return moe_sorting_mp_get_workspace_size(tokens_, num_experts_, topk_); } #else - return moe_sorting_mp_get_workspace_size(tokens_, num_experts_); + return moe_sorting_mp_get_workspace_size(tokens_, num_experts_, topk_); #endif } @@ -1320,6 +1355,7 @@ struct MoeSortingMultiPhaseKernel_P0 using IndexType = typename Problem::IndexType; using WeightType = typename Problem::WeightType; + using MeshType = typename Problem::MeshType; static constexpr index_t BLOCK_SIZE = 256; static constexpr index_t OCCUPANCY = 2; // hard coded @@ -1371,22 +1407,21 @@ struct MoeSortingMultiPhaseKernel_P0 { using topk_id_t = ext_vector_t; - static_assert(Problem::SubTokenTile == 1 || Problem::SubTokenTile == 2 || - Problem::SubTokenTile == 4); - const topk_id_t* p_topk_ids = reinterpret_cast(kargs.p_topk_ids); - IndexType* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh); + MeshType* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh); index_t total_elem = kargs.tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile; #pragma unroll Problem::SubTokenTile - for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem; i += blockDim.x) + for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem; + i += gridDim.x * BLOCK_SIZE) { auto x = p_topk_ids[i]; static_for<0, Problem::SubTokenTile, 1>{}([&](auto j) { IndexType eid = x[j.value]; // ext_vector_type must use int to [] uint32_t curr_token_id, curr_topk_id; kargs.topk_mdiv.divmod(i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id); - p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = curr_topk_id + 1; + p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = + (curr_topk_id + 1) & 0xffff; }); } } @@ -1400,6 +1435,7 @@ struct MoeSortingMultiPhaseKernel_P1 using IndexType = typename Problem::IndexType; using WeightType = typename Problem::WeightType; + using MeshType = typename Problem::MeshType; static constexpr index_t BLOCK_SIZE = 256; static constexpr index_t OCCUPANCY = 2; // hard coded @@ -1420,9 +1456,9 @@ struct MoeSortingMultiPhaseKernel_P1 Kargs k; k.p_local_expert_mask = h.p_local_expert_mask; k.p_expert_mesh = h.p_ws; - k.p_expert_cumsum = - reinterpret_cast(reinterpret_cast(h.p_ws) + - impl::moe_sorting_mp_mesh_elem(h.tokens, h.num_experts)); + k.p_expert_cumsum = reinterpret_cast( + reinterpret_cast(h.p_ws) + + impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk)); k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens); return k; @@ -1444,13 +1480,11 @@ struct MoeSortingMultiPhaseKernel_P1 int eid = blockIdx.x; - constexpr index_t index_pack = 4; // always packed - using r_t = ext_vector_t; // always use int32x4 + constexpr index_t index_pack = Problem::SubTokenTile; // always packed + using r_t = ext_vector_t; // always use int32x4 r_t* p_expert_mesh = reinterpret_cast( - reinterpret_cast(kargs.p_expert_mesh) + eid * kargs.mesh_stride); + reinterpret_cast(kargs.p_expert_mesh) + eid * kargs.mesh_stride); - static_assert(Problem::SubTokenTile == 1 || Problem::SubTokenTile == 2 || - Problem::SubTokenTile == 4); const IndexType* p_local_expert_mask = static_cast(kargs.p_local_expert_mask); IndexType* p_expert_cumsum = reinterpret_cast(kargs.p_expert_cumsum); @@ -1502,6 +1536,197 @@ struct MoeSortingMultiPhaseKernel_P1 } }; +#if MOE_SORTING_FUSE_MP_01 +template +struct MoeSortingMultiPhaseKernel_P01 +{ + using Problem = remove_cvref_t; + + using IndexType = typename Problem::IndexType; + using WeightType = typename Problem::WeightType; + using MeshType = typename Problem::MeshType; + + static constexpr index_t BLOCK_SIZE = 256; + static constexpr index_t OCCUPANCY = 2; // hard coded + + typedef MoeSortingHostArgs MoeSortingKargs; + + using Hargs = MoeSortingHostArgs; + + struct Kargs + { + const void* p_topk_ids; // [tokens, topk] + const void* p_local_expert_mask; // [expert] + void* p_expert_mesh; // [expert, tokens] + void* p_expert_cumsum; // [expert + 1] + void* p_expert_sem; // [1] + index_t tokens; + index_t num_experts; + index_t mesh_stride; // mesh_stride for p_expert_mesh + index_t wg_count; // used for semaphore + mdiv topk_mdiv; + }; + + CK_TILE_HOST static constexpr auto get_num_cu() + { + index_t num_cu = [&]() { + hipDeviceProp_t dev_prop; + hipDevice_t dev; + HIP_CHECK_ERROR(hipGetDevice(&dev)); + HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev)); + return dev_prop.multiProcessorCount; + }(); + return num_cu; + } + + CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) + { + Kargs k; + k.p_topk_ids = h.p_topk_ids; + k.p_local_expert_mask = h.p_local_expert_mask; + k.p_expert_mesh = h.p_ws; + k.p_expert_cumsum = reinterpret_cast( + reinterpret_cast(h.p_ws) + + impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk)); + k.p_expert_sem = reinterpret_cast( + reinterpret_cast(h.p_ws) + + impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk) + + impl::moe_sorting_mp_cumsum_smem_size(h.num_experts)); + k.tokens = h.tokens; + k.num_experts = h.num_experts; + k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens); + k.wg_count = WGCounts(h); + k.topk_mdiv = mdiv{static_cast(h.topk)}; + return k; + } + + CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; } + + CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } + + CK_TILE_HOST static constexpr auto WGCounts(const Hargs& h) + { + index_t total_elem = h.tokens * h.topk / Problem::SubTokenTile; + index_t elem_cnt = (total_elem + BLOCK_SIZE - 1) / BLOCK_SIZE; + + // no more than grid_size + return min(elem_cnt, GridSize(h)); + } + + // in byte + CK_TILE_HOST static constexpr auto GetSmemSize() + { + return BLOCK_SIZE / warpSize * sizeof(IndexType); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + workgroup_barrier wb{reinterpret_cast(kargs.p_expert_sem)}; + + { + using topk_id_t = ext_vector_t; + + const topk_id_t* p_topk_ids = reinterpret_cast(kargs.p_topk_ids); + IndexType* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh); + index_t total_elem = kargs.tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile; + +#pragma unroll Problem::SubTokenTile + for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem; + i += BLOCK_SIZE * gridDim.x) + { + auto x = p_topk_ids[i]; + static_for<0, Problem::SubTokenTile, 1>{}([&](auto j) { + IndexType eid = x[j.value]; // ext_vector_type must use int to [] + uint32_t curr_token_id, curr_topk_id; + kargs.topk_mdiv.divmod( + i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id); + p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = curr_topk_id + 1; + }); + } + if(static_cast(blockIdx.x) < kargs.wg_count) + { + wb.inc(); + } + } + + { + __shared__ char smem[GetSmemSize()]; + int eid = blockIdx.x; + + // early exist in case of extra atomic wait + if(eid >= kargs.num_experts) + return; + + wb.wait_lt(kargs.wg_count); + + for(; eid < kargs.num_experts; eid += gridDim.x) + { + // if(threadIdx.x == 0) + // printf("!!! bid:%d, eid:%d (%d, %d)\n", + // static_cast(blockIdx.x), + // eid, + // kargs.num_experts, + // static_cast(blockDim.x)); + constexpr index_t index_pack = 4; // always packed + using r_t = ext_vector_t; // always use int32x4 + r_t* p_expert_mesh = reinterpret_cast( + reinterpret_cast(kargs.p_expert_mesh) + eid * kargs.mesh_stride); + + const IndexType* p_local_expert_mask = + static_cast(kargs.p_local_expert_mask); + IndexType* p_expert_cumsum = reinterpret_cast(kargs.p_expert_cumsum); + + auto f_sum = [](auto x_, auto y_) { return x_ + y_; }; + + int loops = (kargs.mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE; + + if constexpr(Problem::LocalExpertMasking) + { + IndexType mask = p_local_expert_mask[eid]; + if(mask == 0) + continue; // skip + } + + index_t cnt = 0; // per-wave cnt + for(int i = 0; i < loops; i++) + { + int position = i * BLOCK_SIZE + threadIdx.x; + r_t v{0}; + if(position < (kargs.mesh_stride / index_pack)) + v = p_expert_mesh[position]; + index_t local_sum = 0; + static_for<0, index_pack, 1>{}( + [&](auto i_vec) { local_sum += v[i_vec.value] != 0 ? 1 : 0; }); + cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum); + } + + index_t lane_id = threadIdx.x % warpSize; + index_t wave_id = threadIdx.x / warpSize; + + // reduce cross wave + IndexType* s = reinterpret_cast(smem); + __syncthreads(); + if(lane_id == 0) + { + s[wave_id] = cnt; + } + __syncthreads(); + + if(threadIdx.x == 0) + { + index_t c = 0; + for(auto i = 0; i < (BLOCK_SIZE / warpSize); i++) + { + c += s[i]; + } + p_expert_cumsum[eid] = c; + } + } + } + } +}; +#endif + // token count cumsum template struct MoeSortingMultiPhaseKernel_P2 @@ -1510,6 +1735,7 @@ struct MoeSortingMultiPhaseKernel_P2 using IndexType = typename Problem::IndexType; using WeightType = typename Problem::WeightType; + using MeshType = typename Problem::MeshType; static constexpr index_t BLOCK_SIZE = 256; static constexpr index_t OCCUPANCY = 2; // hard coded @@ -1536,10 +1762,9 @@ struct MoeSortingMultiPhaseKernel_P2 { Kargs k; k.p_local_expert_mask = h.p_local_expert_mask; - // k.p_expert_mesh = h.p_ws; - k.p_expert_cumsum = - reinterpret_cast(reinterpret_cast(h.p_ws) + - impl::moe_sorting_mp_mesh_elem(h.tokens, h.num_experts)); + k.p_expert_cumsum = reinterpret_cast( + reinterpret_cast(h.p_ws) + + impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk)); k.p_total_tokens_post_pad = h.p_total_tokens_post_pad; k.p_sorted_expert_ids = h.p_sorted_expert_ids; @@ -1566,7 +1791,8 @@ struct MoeSortingMultiPhaseKernel_P2 // in byte CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() { - return 2 * BLOCK_SIZE * sizeof(IndexType); + // return 2 * BLOCK_SIZE * sizeof(IndexType); + return (4 + 2 * BLOCK_SIZE / warpSize) * sizeof(IndexType); } // reduce single pixel within a wave @@ -1718,6 +1944,7 @@ struct MoeSortingMultiPhaseKernel_P3 using IndexType = typename Problem::IndexType; using WeightType = typename Problem::WeightType; + using MeshType = typename Problem::MeshType; static constexpr index_t BLOCK_SIZE = 256; static constexpr index_t OCCUPANCY = 2; // hard coded @@ -1749,9 +1976,9 @@ struct MoeSortingMultiPhaseKernel_P3 k.p_sorted_token_ids = h.p_sorted_token_ids; k.p_sorted_weights = h.p_sorted_weights; k.p_expert_mesh = h.p_ws; - k.p_expert_cumsum = - reinterpret_cast(reinterpret_cast(h.p_ws) + - impl::moe_sorting_mp_mesh_elem(h.tokens, h.num_experts)); + k.p_expert_cumsum = reinterpret_cast( + reinterpret_cast(h.p_ws) + + impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk)); k.tokens = h.tokens; k.num_experts = h.num_experts; k.topk_mdiv = mdiv{static_cast(h.topk)}; @@ -1782,9 +2009,6 @@ struct MoeSortingMultiPhaseKernel_P3 const WeightType* p_weights = static_cast(kargs.p_weights); WeightType* p_sorted_weights = reinterpret_cast(kargs.p_sorted_weights); - static_assert(Problem::SubTokenTile == 1 || Problem::SubTokenTile == 2 || - Problem::SubTokenTile == 4); - int eid = blockIdx.x; int wave_id = threadIdx.x / warpSize; int lane_id = threadIdx.x % warpSize; @@ -1866,6 +2090,495 @@ struct MoeSortingMultiPhaseKernel_P3 } }; +namespace impl { +// we use dynamic LDS size here +CK_TILE_HOST constexpr auto moe_sorting_get_smem_size_p23(int num_experts_) +{ + constexpr index_t BLOCK_SIZE = 256; // hardcoded 256 + const index_t expert_cumsum_elem = num_experts_ + 1; + return (4 + 2 * BLOCK_SIZE / warpSize + expert_cumsum_elem) * sizeof(int); +} +} // namespace impl + +// token count cumsum +template +struct MoeSortingMultiPhaseKernel_P23 +{ + using Problem = remove_cvref_t; + + using IndexType = typename Problem::IndexType; + using WeightType = typename Problem::WeightType; + using MeshType = typename Problem::MeshType; + + static constexpr index_t BLOCK_SIZE = 256; + static constexpr index_t OCCUPANCY = 2; // hard coded + + typedef MoeSortingHostArgs MoeSortingKargs; + + using Hargs = MoeSortingHostArgs; + struct Kargs + { + const void* p_weights; + const void* p_local_expert_mask; // [expert] + void* p_expert_mesh; // [expert, tokens] + void* p_expert_cumsum; // [expert + 1] + void* p_total_tokens_post_pad; // [1] + void* p_sorted_expert_ids; + + void* p_sorted_token_ids; + void* p_sorted_weights; + void* p_moe_buf; + + index_t tokens; + index_t num_experts; + index_t mesh_stride; // mesh_stride for p_expert_mesh + mdiv unit_size_mdiv; + mdiv topk_mdiv; + long_index_t moe_buf_bytes; + }; + + CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) + { + Kargs k; + k.p_weights = h.p_weights; + k.p_local_expert_mask = h.p_local_expert_mask; + k.p_expert_mesh = h.p_ws; + k.p_expert_cumsum = reinterpret_cast( + reinterpret_cast(h.p_ws) + + impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk)); + k.p_total_tokens_post_pad = h.p_total_tokens_post_pad; + k.p_sorted_expert_ids = h.p_sorted_expert_ids; + + k.p_sorted_token_ids = h.p_sorted_token_ids; + k.p_sorted_weights = h.p_sorted_weights; + + k.p_moe_buf = h.p_moe_buf; + + k.tokens = h.tokens; + k.num_experts = h.num_experts; + k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens); + k.unit_size_mdiv = mdiv{static_cast(h.unit_size)}; + k.topk_mdiv = mdiv{static_cast(h.topk)}; + + k.moe_buf_bytes = h.moe_buf_bytes; + + return k; + } + + CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) + { + // use 1 block to cumsum + // return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16)); + return dim3(h.num_experts + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16)); + } + + CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } + + // only use this at host ! + CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h) + { + const auto smem_23 = impl::moe_sorting_get_smem_size_p23(h.num_experts); + const auto smem_sf = BLOCK_SIZE * 4 * sizeof(IndexType); + return max(smem_23, smem_sf); + } + + // reduce single pixel within a wave + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + if(static_cast(blockIdx.x) >= kargs.num_experts) + { + impl::moe_buf_set_zero_kernel( + reinterpret_cast(kargs.p_moe_buf), + kargs.moe_buf_bytes, + blockIdx.x - kargs.num_experts); + return; + } + + extern __shared__ char smem[]; + { + IndexType* s = reinterpret_cast(smem); + + const IndexType* p_local_expert_mask = + static_cast(kargs.p_local_expert_mask); + IndexType* p_expert_cumsum = reinterpret_cast(kargs.p_expert_cumsum); + IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / warpSize; + IndexType* p_total_tokens_post_pad = + reinterpret_cast(kargs.p_total_tokens_post_pad); + IndexType* p_sorted_expert_ids = + reinterpret_cast(kargs.p_sorted_expert_ids); + + const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE; + index_t wave_id = threadIdx.x / warpSize; + index_t lane_id = threadIdx.x % warpSize; + + IndexType prev_cumsum_a = 0; + IndexType prev_cumsum_b = 0; + + for(index_t i = 0; i < loops; i++) + { + index_t position = i * BLOCK_SIZE + threadIdx.x; + IndexType a_ = 0; // token count for a expert + IndexType b_ = 0; // mask for a expert + if(position < kargs.num_experts) + { + a_ = p_expert_cumsum[position]; + if constexpr(Problem::LocalExpertMasking) + b_ = p_local_expert_mask[position]; + } + + int blocks_pers_expert = + kargs.unit_size_mdiv.div(a_ + kargs.unit_size_mdiv.divisor - 1); + // pad token + int padded_blocks_per_expert = [&]() { + int x_ = [&]() { + if constexpr(Problem::SkipExpertsWithZeroTokens) + { + // if local_cnt is zero, blocks_pers_expert will be zero + // this is what we want to achieve + return blocks_pers_expert; // * kargs.unit_size_mdiv.divisor; + } + else + { + return max(blocks_pers_expert, 1); + } + }(); + if constexpr(Problem::LocalExpertMasking) + { + return b_ ? x_ : 0; + } + else + return x_; + }(); + + IndexType cumsum_a = padded_blocks_per_expert; + IndexType cumsum_b = b_; + + // Note: we first cumsum local round, then add previous cumsum + impl::moe_sorting_wave_cumsum(cumsum_a); + impl::moe_sorting_wave_cumsum(cumsum_b); + + __syncthreads(); + if(lane_id == warpSize - 1) + { + s[4 + wave_id] = cumsum_a; + s[4 + wave_id + BLOCK_SIZE / warpSize] = cumsum_b; + } + + __syncthreads(); + + // reduce cross wave + static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) { + IndexType prev_a = s[4 + i_w]; + IndexType prev_b = s[4 + i_w + BLOCK_SIZE / warpSize]; + prev_a = wave_id > i_w ? prev_a : 0; // mask out + prev_b = wave_id > i_w ? prev_b : 0; // mask out + cumsum_a += prev_a; + cumsum_b += prev_b; + }); + + // Now let's add previous cumsum + cumsum_a += prev_cumsum_a; + cumsum_b += prev_cumsum_b; + + if(threadIdx.x == BLOCK_SIZE - 1) + { + s[2] = cumsum_a; // store the last cumsum + s[3] = cumsum_b; + } + + IndexType out_0 = cumsum_a - padded_blocks_per_expert; // exclusive cumsum tok cnt + IndexType out_1 = cumsum_b - b_; // exclusive cumsum mask cnt + + __syncthreads(); + prev_cumsum_a = s[2]; + prev_cumsum_b = s[3]; + + if(position < kargs.num_experts) + { + p_expert_cumsum_smem[position] = out_0 * kargs.unit_size_mdiv.divisor; + } + + { + if(blockIdx.x == 0) + { + if constexpr(Problem::LocalExpertMasking) + { + if(b_) + { + for(int j = 0; j < blocks_pers_expert; j++) + { + p_sorted_expert_ids[out_0 + j] = out_1; + } + } + } + else + { + for(int j = 0; j < blocks_pers_expert; j++) + { + p_sorted_expert_ids[out_0 + j] = position; + } + } + } + } + } + + if(threadIdx.x == 0) + { + auto total_tokens_post_pad = prev_cumsum_a * kargs.unit_size_mdiv.divisor; + if(blockIdx.x == 0) + p_total_tokens_post_pad[0] = total_tokens_post_pad; + p_expert_cumsum_smem[kargs.num_experts] = total_tokens_post_pad; + } + } + + __syncthreads(); + + { + const IndexType* p_local_expert_mask = + static_cast(kargs.p_local_expert_mask); + IndexType* s = reinterpret_cast(smem); + MeshType* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh); + IndexType* p_sorted_token_ids = reinterpret_cast(kargs.p_sorted_token_ids); + IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / warpSize; + const WeightType* p_weights = static_cast(kargs.p_weights); + WeightType* p_sorted_weights = reinterpret_cast(kargs.p_sorted_weights); + + int eid = blockIdx.x; + int wave_id = threadIdx.x / warpSize; + int lane_id = threadIdx.x % warpSize; + int e_start = p_expert_cumsum_smem[eid]; + int e_end = p_expert_cumsum_smem[eid + 1]; + if constexpr(Problem::SkipExpertsWithZeroTokens) + { + if(e_start == e_end) + return; + } + + if constexpr(Problem::LocalExpertMasking) + { + int e_mask = p_local_expert_mask[eid]; + if(e_mask == 0) + return; // skip empty expert + } + + // cumsum one by one + constexpr index_t index_pack = Problem::SubTokenTile; // always packed + using r_t = ext_vector_t; // always use int32x4 + using d_t = ext_vector_t; + int loops = (kargs.mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE; + int prev_cumsum = 0; + + for(int i = 0; i < loops; i++) + { + int i_token_pack = i * BLOCK_SIZE + threadIdx.x; + r_t x_v = 0; + if(i_token_pack < (kargs.tokens + index_pack - 1) / index_pack) + { + x_v = reinterpret_cast(p_expert_mesh + + eid * kargs.mesh_stride)[i_token_pack]; + } + + r_t x_r; +#if 0 + if constexpr(index_pack != 1) + { + // shuffle, we must have contiguout thread holds contiguout token + __syncthreads(); + reinterpret_cast(s)[threadIdx.x] = x_v; + __syncthreads(); + + static_for<0, index_pack, 1>{}([&](auto j_) { + constexpr auto j = j_.value; + x_r[j] = reinterpret_cast(s)[threadIdx.x + j * BLOCK_SIZE]; + }); + } +#else + x_r = x_v; +#endif + { +#if 0 +#pragma unroll + for(int j = 0; j < index_pack / 2; j++) + { + int i_token = i * BLOCK_SIZE * index_pack + threadIdx.x + j * BLOCK_SIZE; + index_t x = x_d[j]; + int i_topk = x - 1; // topk of this token + int i_show = x != 0 ? 1 : 0; // has this token or not + int cumsum = i_show; + impl::moe_sorting_wave_cumsum(cumsum); + + __syncthreads(); + if(lane_id == warpSize - 1) + { + s[4 + wave_id] = cumsum; + } + __syncthreads(); + + // reduce cross wave + static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) { + IndexType prev = s[4 + i_w]; + prev = wave_id > i_w ? prev : 0; // mask out + cumsum += prev; + }); + cumsum += prev_cumsum; // add previous round cumsum + if(threadIdx.x == BLOCK_SIZE - 1) + { + s[0] = cumsum; + } + __syncthreads(); + + int position = cumsum - i_show; + prev_cumsum = s[0]; // update the last cumsum + + if(i_show) + { +#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID + p_sorted_token_ids[e_start + position] = + MOE_SORTING_MOCK_ID(i_token, i_topk); +#else + p_sorted_token_ids[e_start + position] = i_token; +#endif + p_sorted_weights[e_start + position] = + p_weights[i_token * kargs.topk_mdiv.divisor + i_topk]; + } + } +#endif + { + d_t i_topk; + d_t i_show; + // = 0; + int cumsum_store = 0; + + static_for<0, index_pack, 1>{}([&](auto j_) { + constexpr auto j = j_.value; + i_topk[j] = static_cast(x_r[j] - 1); + i_show[j] = static_cast(x_r[j] != 0 ? 1 : 0); + cumsum_store += i_show[j]; + }); + int cumsum = cumsum_store; + impl::moe_sorting_wave_cumsum(cumsum); + + __syncthreads(); + if(lane_id == warpSize - 1) + { + s[4 + wave_id] = cumsum; + } + __syncthreads(); + + // reduce cross wave + static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) { + IndexType prev = s[4 + i_w]; + prev = wave_id > i_w ? prev : 0; // mask out + cumsum += prev; + }); + cumsum += prev_cumsum; // add previous round cumsum + if(threadIdx.x == BLOCK_SIZE - 1) + { + s[0] = cumsum; + } + __syncthreads(); + prev_cumsum = s[0]; // update the last cumsum + + int position = cumsum - cumsum_store; + static_for<0, index_pack, 1>{}([&](auto j_) { + constexpr auto j = j_.value; + // int i_token = i * BLOCK_SIZE * index_pack + threadIdx.x + j * + // BLOCK_SIZE; + int i_token = + i * BLOCK_SIZE * index_pack + threadIdx.x * index_pack + j; + + if(i_show[j]) + { +#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID + p_sorted_token_ids[e_start + position] = + MOE_SORTING_MOCK_ID(i_token, i_topk[j]); +#else + p_sorted_token_ids[e_start + position] = i_token; +#endif + p_sorted_weights[e_start + position] = + p_weights[i_token * kargs.topk_mdiv.divisor + i_topk[j]]; + } + position += i_show[j]; + }); + +#if 0 + int i_token = i * BLOCK_SIZE * index_pack + threadIdx.x * 2 + j * BLOCK_SIZE * 2; + index_t x = x_d[j]; + index_t x0 = static_cast(x & 0xffff); + index_t x1 = static_cast(x >> 16); + int i_topk_0 = x0 - 1; // topk of this token + int i_show_0 = x0 != 0 ? 1 : 0; // has this token or not + int i_topk_1 = x1 - 1; // topk of this token + int i_show_1 = x1 != 0 ? 1 : 0; // has this token or not + int cumsum = i_show_0 + i_show_1; + impl::moe_sorting_wave_cumsum(cumsum); + + __syncthreads(); + if(lane_id == warpSize - 1) + { + s[4 + wave_id] = cumsum; + } + __syncthreads(); + + // reduce cross wave + static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) { + IndexType prev = s[4 + i_w]; + prev = wave_id > i_w ? prev : 0; // mask out + cumsum += prev; + }); + cumsum += prev_cumsum; // add previous round cumsum + if(threadIdx.x == BLOCK_SIZE - 1) + { + s[0] = cumsum; + } + __syncthreads(); + + int position_0 = cumsum - i_show_0 - i_show_1; + prev_cumsum = s[0]; // update the last cumsum + + if(i_show_0) + { +#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID + p_sorted_token_ids[e_start + position_0] = + MOE_SORTING_MOCK_ID(i_token, i_topk_0); +#else + p_sorted_token_ids[e_start + position_0] = i_token; +#endif + p_sorted_weights[e_start + position_0] = + p_weights[i_token * kargs.topk_mdiv.divisor + i_topk_0]; + } + + int position_1 = cumsum - i_show_1; + + if(i_show_1) + { +#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID + p_sorted_token_ids[e_start + position_1] = + MOE_SORTING_MOCK_ID(i_token + 1, i_topk_1); +#else + p_sorted_token_ids[e_start + position_1] = i_token + 1; +#endif + p_sorted_weights[e_start + position_1] = + p_weights[(i_token + 1) * kargs.topk_mdiv.divisor + i_topk_1]; + } +#endif + } + } + } + + for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += BLOCK_SIZE) + { +#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID + p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(kargs.tokens, kargs.topk_mdiv.divisor); +#else + p_sorted_token_ids[i] = tokens; +#endif + p_sorted_weights[i] = static_cast(0.0); + } + } + } +}; + #undef MOE_SORTING_MOCK_ID } // namespace ck_tile diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp index a98e0d7652..39bc6ca93e 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp @@ -50,20 +50,23 @@ struct MoeSortingProblemEx }; template struct MoeSortingProblemMp { // TODO: this kernel only support warp per row using WeightType = remove_cvref_t; + using MeshType = remove_cvref_t; using IndexType = remove_cvref_t; static constexpr index_t SubTokenTile = SubTokenTile_; static constexpr bool LocalExpertMasking = LocalExpertMasking_; static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_; - static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4); + static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 || + SubTokenTile == 8 || SubTokenTile == 16); }; } // namespace ck_tile