diff --git a/example/ck_tile/13_moe_sorting/README.md b/example/ck_tile/13_moe_sorting/README.md index 7b6792dd95..1822ff3a37 100644 --- a/example/ck_tile/13_moe_sorting/README.md +++ b/example/ck_tile/13_moe_sorting/README.md @@ -14,14 +14,24 @@ This will result in an executable `build/bin/tile_example_moe_sorting` ## example ``` args: - -v weather do CPU validation or not (default:1) - -pr_i index data type. (currently only fp32 supported now) (default:int32) - -pr_w output weight data type(currently only fp32 supported now) (default:fp32) - -t number of input tokens (default:32) - -e number of experts (default:8) - -k topk (default:2) - -st_i row stride of input, -1 means same as experts (default:-1) - -seed seed to be used, -1 means random every time (default:-1) - -kname when set to 1 it will print kernel name (default:0) + -v turn CPU validation on (1) or off (0). (default:1) + -pr_i index data type. Only int32 is currently supported. (default:int32) + -pr_w output weight data type. Only fp32 is currently supported. (default:fp32) + -t number of input tokens. (default:128) + If "local_t" presents, this value indicates global concurrency of all ranks. + -local_t Number of local input tokens for curent rank. (default:-1) + This value must be within range "[0, t)", or "-1"(no such feature) + This feature is to simulate EP case where where each rank has different tokens. + Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph. + -e number of num_experts (default:8) + -k topk (default:4) + -unit unit_size (default:32) +-moe_buf_size moe_buf_size (default:0) + -local_eid a list of experts enabled as local expert. e.g. "0,1,4,5" (default:-1) + please make sure eid is in ascending order! + -seed seed to be used. When set to -1, a random seed will be generated each time invoking this example (default:-1) + -kname prints the kernel name when set to 1 (default:0) + -warmup number of iterations before benchmark the kernel (default:5) + -repeat number of iterations to benchmark the kernel (default:20) ``` diff --git a/example/ck_tile/13_moe_sorting/moe_sorting.cpp b/example/ck_tile/13_moe_sorting/moe_sorting.cpp index da1c15b86f..f139081cd4 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting.cpp @@ -18,10 +18,20 @@ auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; - arg_parser.insert("v", "1", "weather do CPU validation or not") - .insert("pr_i", "int32", "index data type. (currently only int32 supported now)") - .insert("pr_w", "fp32", "output weight data type(currently only fp32 supported now)") - .insert("t", "128", "number of input tokens") + arg_parser.insert("v", "1", "turn CPU validation on (1) or off (0).") + .insert("pr_i", "int32", "index data type. Only int32 is currently supported.") + .insert("pr_w", "fp32", "output weight data type. Only fp32 is currently supported.") + .insert("t", + "128", + "number of input tokens.\n" + "If \"local_t\" presents, this value indicates global concurrency of all ranks.") + .insert( + "local_t", + "-1", + "Number of local input tokens for curent rank.\n" + "This value must be within range \"[0, t)\", or \"-1\"(no such feature)\n" + "This feature is to simulate EP case where where each rank has different tokens.\n" + "Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.") .insert("e", "8", "number of num_experts") .insert("k", "4", "topk") .insert("unit", "32", "unit_size") @@ -30,8 +40,11 @@ auto create_args(int argc, char* argv[]) "-1", "a list of experts enabled as local expert. e.g. \"0,1,4,5\"\n" "please make sure eid is in ascending order!") - .insert("seed", "-1", "seed to be used, -1 means random every time") - .insert("kname", "0", "when set to 1 it will print kernel name") + .insert("seed", + "-1", + "seed to be used. When set to -1, a random seed will be generated each time " + "invoking this example") + .insert("kname", "0", "prints the kernel name when set to 1") .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel"); @@ -70,6 +83,7 @@ bool test_moe_sorting(ck_tile::ArgParser args) std::string index_prec = args.get_str("pr_i"); std::string weight_prec = args.get_str("pr_w"); int tokens = args.get_int("t"); + int local_tokens = args.get_int("local_t"); int num_experts = args.get_int("e"); int topk = args.get_int("k"); int seed = args.get_int("seed"); @@ -95,6 +109,16 @@ bool test_moe_sorting(ck_tile::ArgParser args) return false; } + // if local_tokens == tokens, not local_token, but better avoid this since no meaning for such + // case + bool is_local_token = local_tokens >= 0 && local_tokens < tokens; + + if(local_tokens > tokens) + { + printf("local_tokens:%d larger than tokens:%d, invalid\n", local_tokens, tokens); + return false; + } + bool local_expert_masking = args.get_str("local_eid") != "-1"; auto local_expert_masking_host = [&]() { if(local_expert_masking) @@ -143,6 +167,13 @@ bool test_moe_sorting(ck_tile::ArgParser args) ck_tile::DeviceMem local_expert_masking_dev( local_expert_masking_host.get_element_space_size_in_bytes()); + // used for simulating dynamic_tokens for EP case + ck_tile::DeviceMem local_tokens_dev(sizeof(ck_tile::index_t)); + if(is_local_token) + { + local_tokens_dev.ToDevice(&local_tokens); + } + topk_ids_dev.ToDevice(topk_ids_host.data()); weights_dev.ToDevice(weights_host.data()); if(moe_buf_size > 0) @@ -164,6 +195,7 @@ bool test_moe_sorting(ck_tile::ArgParser args) weights_dev.GetDeviceBuffer(), local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer() : nullptr, + is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr, sorted_ids_dev.GetDeviceBuffer(), sorted_weights_dev.GetDeviceBuffer(), sorted_expert_ids_dev.GetDeviceBuffer(), @@ -236,13 +268,12 @@ bool test_moe_sorting(ck_tile::ArgParser args) } #endif - printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, mp:%d, ", - index_prec.c_str(), - weight_prec.c_str(), - tokens, - num_experts, - topk, - workspace_size != 0 ? 1 : 0); + printf("[%s|%s]tokens:%d", index_prec.c_str(), weight_prec.c_str(), tokens); + if(is_local_token) + { + printf("(%d)", local_tokens); + } + printf(", num_experts:%d, topk:%d, mp:%d, ", num_experts, topk, workspace_size != 0 ? 1 : 0); if(local_expert_masking) { @@ -285,6 +316,8 @@ bool test_moe_sorting(ck_tile::ArgParser args) ref_total_tokens_post_pad, num_experts, unit_size, + is_local_token ? local_tokens + : tokens, local_expert_masking); printf("total_tokens_post_pad:%d(%d), ", ref_total_tokens_post_pad, diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp index 305cf118d2..0899fefcfc 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp @@ -33,15 +33,18 @@ #else -#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \ +#define MOE_SORTING_DISPATCH_( \ + sub_token_tile_, sub_token_onshot_, local_expert_masking_, local_token_) \ constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \ constexpr bool sub_token_onshot = sub_token_onshot_; \ constexpr bool local_expert_masking = local_expert_masking_; \ + constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemEx; \ + local_expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingKernel; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -51,32 +54,43 @@ s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ return ave_time; -#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \ - if(row_ % 8 == 0) \ - { \ - MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \ - } \ - else if(row_ % 4 == 0) \ - { \ - MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \ - } \ - else if(row_ % 2 == 0) \ - { \ - MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \ - } \ - else \ - { \ - MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \ +#define MOE_SORTING_DISPATCH_SUB_TOKEN_( \ + row_, sub_token_onshot_, local_expert_masking_, local_token_) \ + if(row_ % 8 == 0) \ + { \ + MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_, local_token_); \ + } \ + else if(row_ % 4 == 0) \ + { \ + MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_, local_token_); \ + } \ + else if(row_ % 2 == 0) \ + { \ + MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_, local_token_); \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_, local_token_); \ } -#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \ - if(is_sub_token_onshot) \ - { \ - MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \ - } \ - else \ - { \ - MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \ +#define MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \ + if(is_local_token) \ + { \ + MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_, true) \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_, false) \ + } + +#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \ + if(is_sub_token_onshot) \ + { \ + MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, true, local_expert_masking_) \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, false, local_expert_masking_) \ } #define MOE_SORTING_DISPATCH_EMASK_(row_) \ @@ -171,6 +185,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi auto row_ = sub_token_ / 8; bool is_sub_token_onshot = a.tokens <= sub_token_; bool is_local_expert_masking = t.local_expert_masking; + bool is_local_token = a.p_local_tokens != nullptr; MOE_SORTING_DISPATCH_EMASK_(row_); // MOE_SORTING_DISPATCH_ETILE(0, 0); @@ -179,15 +194,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi return -1; } -#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_) \ +#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_) \ [&]() { \ constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr bool expert_masking = expert_masking_; \ + constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -195,15 +212,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() -#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_) \ +#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \ [&]() { \ constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr bool expert_masking = expert_masking_; \ + constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -211,15 +230,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() #if MOE_SORTING_SUPPORT_LARGE_EXPERT -#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_) \ +#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_, local_token_) \ [&]() { \ constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr bool expert_masking = expert_masking_; \ + constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -227,15 +248,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() -#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_) \ +#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_, local_token_) \ [&]() { \ constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr bool expert_masking = expert_masking_; \ + constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -244,15 +267,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi }() #endif -#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_) \ +#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_, local_token_) \ [&]() { \ constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr bool expert_masking = expert_masking_; \ + constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -261,28 +286,53 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \ }() -#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ - if(t.local_expert_masking) \ - { \ - float ave_time = \ - ck_tile::launch_kernel(s, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true), \ - MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true), \ - MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true)); \ - return ave_time; \ - } \ - else \ - { \ - float ave_time = \ - ck_tile::launch_kernel(s, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false), \ - MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false), \ - MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false)); \ - return ave_time; \ +#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ + if(t.local_expert_masking) \ + { \ + if(is_local_token) \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \ + return ave_time; \ + } \ + else \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \ + return ave_time; \ + } \ + } \ + else \ + { \ + if(is_local_token) \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \ + return ave_time; \ + } \ + else \ + { \ + float ave_time = ck_tile::launch_kernel( \ + s, \ + MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \ + return ave_time; \ + } \ } float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) { + bool is_local_token = a.p_local_tokens != nullptr; if(t.weight_type == "fp32" && t.index_type == "int32") { using ms_index_t = ck_tile::index_t; diff --git a/example/ck_tile/13_moe_sorting/script/smoke_test.sh b/example/ck_tile/13_moe_sorting/script/smoke_test.sh index fbfb10822c..63bc0acceb 100644 --- a/example/ck_tile/13_moe_sorting/script/smoke_test.sh +++ b/example/ck_tile/13_moe_sorting/script/smoke_test.sh @@ -31,4 +31,14 @@ $EXE -t=8192 -e=32 -k=5 -moe_buf_size=163840 $EXE -t=8192 -e=32 -k=8 -moe_buf_size=163840 $EXE -t=8192 -e=256 -k=5 -moe_buf_size=163840 $EXE -t=8192 -e=256 -k=8 -moe_buf_size=163840 -$EXE -t=163840 -e=256 -k=8 -moe_buf_size=163840 \ No newline at end of file +$EXE -t=163840 -e=256 -k=8 -moe_buf_size=163840 +$EXE -t=12 -local_t=3 -e=256 -k=5 -local_eid=9,10,199,145 +$EXE -t=67 -local_t=9 -e=555 -k=5 -local_eid=19,23,24,25,26,99 +$EXE -t=99 -local_t=93 -e=121 -moe_buf_size=10244 +$EXE -t=536 -local_t=345 -e=802 -k=99 +$EXE -t=331 -local_t=39 -e=83 -k=33 +$EXE -t=765 -local_t=654 -e=783 -k=8 +$EXE -t=23 -local_t=9 -e=1 -k=1 +$EXE -t=7 -local_t=0 -e=89 -k=1 -local_eid=0,8,12,33 +$EXE -t=61 -local_t=0 -e=333 -k=99 -local_eid=0,8,12,33 +$EXE -t=133940 -local_t=111921 -e=256 -k=17 -moe_buf_size=133940 diff --git a/example/ck_tile/15_fused_moe/fused_moe.hpp b/example/ck_tile/15_fused_moe/fused_moe.hpp index 46425384cc..e4c25217fb 100644 --- a/example/ck_tile/15_fused_moe/fused_moe.hpp +++ b/example/ck_tile/15_fused_moe/fused_moe.hpp @@ -16,6 +16,7 @@ struct fused_moe_args const void* d_scale_ptr; // [e, 1, k], down scale const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input const void* local_expert_mask_ptr; // [e], local_expert_mask_ptr for EP + const void* local_tokens; // [1] if not nullptr, tokens read from here void* o_ptr; // [m, k], output token (no need to do zeroing) void* ws_ptr; // size is moe_sorting_get_workspace_size() // if return zero, then could be nullptr diff --git a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp index b3515b1bec..27274878a2 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp @@ -28,6 +28,7 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf a.topk_ids_ptr, // const void* p_topk_ids; a.topk_weight_ptr, // const void* p_weights; a.local_expert_mask_ptr, // const void* p_local_expert_mask; + a.local_tokens, a.sorted_token_ids_ptr, // void* p_sorted_token_ids; a.sorted_weight_ptr, // void* p_sorted_weights; a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids; diff --git a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp index 0d83c48d02..f745284f3e 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp @@ -33,15 +33,18 @@ #else -#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \ +#define MOE_SORTING_DISPATCH_( \ + sub_token_tile_, sub_token_onshot_, local_expert_masking_, local_token_) \ constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \ constexpr bool sub_token_onshot = sub_token_onshot_; \ constexpr bool local_expert_masking = local_expert_masking_; \ + constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemEx; \ + local_expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingKernel; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -51,32 +54,43 @@ s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ return ave_time; -#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \ - if(row_ % 8 == 0) \ - { \ - MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \ - } \ - else if(row_ % 4 == 0) \ - { \ - MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \ - } \ - else if(row_ % 2 == 0) \ - { \ - MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \ - } \ - else \ - { \ - MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \ +#define MOE_SORTING_DISPATCH_SUB_TOKEN_( \ + row_, sub_token_onshot_, local_expert_masking_, local_token_) \ + if(row_ % 8 == 0) \ + { \ + MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_, local_token_); \ + } \ + else if(row_ % 4 == 0) \ + { \ + MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_, local_token_); \ + } \ + else if(row_ % 2 == 0) \ + { \ + MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_, local_token_); \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_, local_token_); \ } -#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \ - if(is_sub_token_onshot) \ - { \ - MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \ - } \ - else \ - { \ - MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \ +#define MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \ + if(is_local_token) \ + { \ + MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_, true) \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_, false) \ + } + +#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \ + if(is_sub_token_onshot) \ + { \ + MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, true, local_expert_masking_) \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, false, local_expert_masking_) \ } #define MOE_SORTING_DISPATCH_EMASK_(row_) \ @@ -175,6 +189,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til auto row_ = sub_token_ / 8; bool is_sub_token_onshot = a.tokens <= sub_token_; bool is_local_expert_masking = t.local_expert_masking; + bool is_local_token = a.p_local_tokens != nullptr; MOE_SORTING_DISPATCH_EMASK_(row_); // MOE_SORTING_DISPATCH_ETILE(0, 0); @@ -183,15 +198,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til return -1; } -#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_) \ +#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_) \ [&]() { \ constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr bool expert_masking = expert_masking_; \ + constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -199,15 +216,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() -#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_) \ +#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \ [&]() { \ constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr bool expert_masking = expert_masking_; \ + constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -215,15 +234,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() #if MOE_SORTING_SUPPORT_LARGE_EXPERT -#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_) \ +#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_, local_token_) \ [&]() { \ constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr bool expert_masking = expert_masking_; \ + constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -231,15 +252,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() -#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_) \ +#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_, local_token_) \ [&]() { \ constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr bool expert_masking = expert_masking_; \ + constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -248,15 +271,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til }() #endif -#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_) \ +#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_, local_token_) \ [&]() { \ constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr bool expert_masking = expert_masking_; \ + constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -265,30 +290,55 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \ }() -#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ - if(t.local_expert_masking) \ - { \ - float ave_time = \ - ck_tile::launch_kernel(s, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true), \ - MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true), \ - MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true)); \ - return ave_time; \ - } \ - else \ - { \ - float ave_time = \ - ck_tile::launch_kernel(s, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false), \ - MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false), \ - MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false)); \ - return ave_time; \ +#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ + if(t.local_expert_masking) \ + { \ + if(is_local_token) \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \ + return ave_time; \ + } \ + else \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \ + return ave_time; \ + } \ + } \ + else \ + { \ + if(is_local_token) \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \ + return ave_time; \ + } \ + else \ + { \ + float ave_time = ck_tile::launch_kernel( \ + s, \ + MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \ + return ave_time; \ + } \ } float fused_moesorting_mp(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s) { + bool is_local_token = a.p_local_tokens != nullptr; if(t.weight_type == "fp32" && t.index_type == "int32") { using ms_index_t = ck_tile::index_t; @@ -360,3 +410,8 @@ float fused_moesorting_mp(fused_moesorting_trait t, } return -1; } + +int fused_moesorting_get_workspace_size(int tokens, int num_experts, int topk) +{ + return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk); +} diff --git a/example/ck_tile/15_fused_moe/main.cpp b/example/ck_tile/15_fused_moe/main.cpp index da843891ce..d9950426a2 100644 --- a/example/ck_tile/15_fused_moe/main.cpp +++ b/example/ck_tile/15_fused_moe/main.cpp @@ -87,7 +87,18 @@ void topid_unique_gen( auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; - arg_parser.insert("t", "128", "num input tokens") + arg_parser + .insert("t", + "128", + "number of input tokens.\n" + "If \"local_t\" presents, this value indicates global concurrency of all ranks.") + .insert( + "local_t", + "-1", + "Number of local input tokens for curent rank.\n" + "This value must be within range \"[0, t)\", or \"-1\"(no such feature)\n" + "This feature is to simulate EP case where where each rank has different tokens.\n" + "Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.") .insert("e", "32", "num of experts") .insert("k", "5", "topk") .insert("h", "8192", "hidden_size of this model") @@ -131,6 +142,7 @@ template = 0 && local_tokens < tokens; + + if(local_tokens > tokens) + { + printf("local_tokens:%d larger than tokens:%d, invalid\n", local_tokens, tokens); + return false; + } + auto prec_str = [&]() { auto base_str = prec_i; if(prec_i != prec_w) @@ -198,11 +218,17 @@ bool run(const ck_tile::ArgParser& arg_parser) return std::string(", st:") + std::to_string(stride); }(); + std::cout << "[" << api_str << "|" << prec_str << "]" + << " t:" << tokens; + + if(is_local_token) + { + std::cout << "(" << local_tokens << ")"; + } + std::cout - << "[" << api_str << "|" << prec_str << "]" - << " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str - << ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp - << ", act:" + << ", e:" << experts << ", k:" << topk << stride_str << ", hidden:" << hidden_size + << ", interm:" << intermediate_size << ", tp:" << tp << ", act:" << activation // << ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1 << (gate_only ? ", g1u0" : ", g1u1") << ", q:" << fused_quant << std::flush; @@ -377,6 +403,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0); if(workspace_size != 0) moe_sorting_ws.SetZero(); // note, clear here!!!! + ck_tile::DeviceMem local_tokens_dev(sizeof(ck_tile::index_t)); + if(is_local_token) + { + local_tokens_dev.ToDevice(&local_tokens); + } fused_moe_traits traits{prec_i, prec_w, @@ -400,6 +431,7 @@ bool run(const ck_tile::ArgParser& arg_parser) fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr, local_expert_masking ? local_expert_mask_buf.GetDeviceBuffer() : nullptr, + is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr, o_buf.GetDeviceBuffer(), workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr, topk_ids_buf.GetDeviceBuffer(), @@ -463,6 +495,7 @@ bool run(const ck_tile::ArgParser& arg_parser) num_sorted_tiles_host.mData[0], experts, block_m, + is_local_token ? local_tokens : tokens, local_expert_masking); if(activation == 0) { @@ -495,6 +528,7 @@ bool run(const ck_tile::ArgParser& arg_parser) num_sorted_tiles_host.mData[0], experts, block_m, + is_local_token ? local_tokens : tokens, local_expert_masking); // done, preparing GPU buffer @@ -506,6 +540,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem sd_buf(sd_host); ck_tile::DeviceMem sy_buf(sy_host); ck_tile::DeviceMem o_buf(o_host); + ck_tile::DeviceMem local_tokens_dev(sizeof(ck_tile::index_t)); + if(is_local_token) + { + local_tokens_dev.ToDevice(&local_tokens); + } // manually clear output buffer for atomic o_buf.SetZero(); @@ -542,7 +581,7 @@ bool run(const ck_tile::ArgParser& arg_parser) num_sorted_tiles_buf.GetDeviceBuffer(), hidden_size, intermediate_size / tp, - tokens, + is_local_token ? local_tokens : tokens, experts, topk, stride}; diff --git a/include/ck_tile/host/reference/reference_moe_sorting.hpp b/include/ck_tile/host/reference/reference_moe_sorting.hpp index 47f0ba576b..1e877b9933 100644 --- a/include/ck_tile/host/reference/reference_moe_sorting.hpp +++ b/include/ck_tile/host/reference/reference_moe_sorting.hpp @@ -21,10 +21,12 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor& topk_ids, index_t& unit_cnt, const index_t experts, const index_t unit_size, + const index_t tokens, bool local_expert_masking, bool skip_experts_with_zero_token = true) { - const index_t num_token = topk_ids.mDesc.get_lengths()[0]; + // note: if tokens is smaller than topk_ids.mDesc.get_lengths()[0], indicating local_token case + const index_t num_token = tokens; // topk_ids.mDesc.get_lengths()[0]; const index_t topk = topk_ids.mDesc.get_lengths()[1]; // allocate a temp buffer, and fill the value with [number_token|topk] std::vector> expert_tokens( diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index d3c98d7bca..3e2e100025 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -165,7 +165,8 @@ struct MoeSortingHostArgs const void* p_topk_ids; // [token, topk] const void* p_weights; // [token, topk] - const void* p_local_expert_mask; + const void* p_local_expert_mask; // [experts] + const void* p_local_tokens; // [1] if not nullptr, tokens read from here void* p_sorted_token_ids; void* p_sorted_weights; @@ -177,7 +178,7 @@ struct MoeSortingHostArgs void* p_ws; // size is moe_sorting_get_workspace_size() // if return zero, then could be nullptr // must be cleard before use - index_t tokens; + index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens used for ws/LDS calculation index_t unit_size; // this is the M_a of fused-moe kernel index_t num_experts; index_t topk; @@ -201,6 +202,7 @@ struct MoeSortingKernel const void* p_topk_ids; const void* p_weights; const void* p_local_expert_mask; + const void* p_local_tokens; // [1] if not nullptr, tokens read from here void* p_sorted_token_ids; void* p_sorted_weights; void* p_sorted_expert_ids; @@ -253,6 +255,7 @@ struct MoeSortingKernel k.p_topk_ids = h.p_topk_ids; k.p_weights = h.p_weights; k.p_local_expert_mask = h.p_local_expert_mask; + k.p_local_tokens = h.p_local_tokens; k.p_sorted_token_ids = h.p_sorted_token_ids; k.p_sorted_weights = h.p_sorted_weights; k.p_sorted_expert_ids = h.p_sorted_expert_ids; @@ -263,9 +266,13 @@ struct MoeSortingKernel k.moe_buf_bytes = h.moe_buf_bytes; const auto blocks = BlockSize(h); + // NOTE: tokens could from p_local_tokens, so here this variable is useless + // hence moe_align_block_size_kernel() will not behavior properly if we have dynamic tokens + // (indeed we can deprecate moe_align_block_size_kernel) k.tokens_per_thread = integer_divide_ceil(h.tokens * h.topk, blocks.x); k.unit_size_mdiv = mdiv{static_cast(h.unit_size)}; k.topk_mdiv = mdiv{static_cast(h.topk)}; + // NOTE: tokens could from p_local_tokens, so here the LDS will be bigger than expected (but works) k.smem_rows = [&](){ auto [r_, c_] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts); (void) c_; @@ -1009,8 +1016,19 @@ struct MoeSortingKernel } const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor; extern __shared__ char smem[]; + #if MOE_SORTING_USE_EX_KERNEL (void)numel; + index_t tokens_ = [&]() { + if constexpr(Problem::LocalToken) + { + return reinterpret_cast(kargs.p_local_tokens)[0]; + } + else + { + return kargs.tokens; + } + }(); return moe_align_block_size_kernel_ex( static_cast(kargs.p_topk_ids), static_cast(kargs.p_weights), @@ -1020,7 +1038,7 @@ struct MoeSortingKernel static_cast(kargs.p_sorted_expert_ids), static_cast(kargs.p_total_tokens_post_pad), kargs.num_experts, - kargs.tokens, + tokens_, kargs.unit_size_mdiv, kargs.topk_mdiv, kargs.expert_mdiv, @@ -1245,6 +1263,7 @@ CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, long_index_t buf_by } // namespace impl +// TODO: tokens could be from // prefer to run mp kernel if is not oneshot CK_TILE_HOST bool moe_sorting_is_oneshot(int tokens_, int num_experts_) { @@ -1351,9 +1370,11 @@ struct MoeSortingMultiPhaseKernel_P0 struct Kargs { - const void* p_topk_ids; // [tokens, topk] - void* p_expert_mesh; // [expert, tokens] - index_t tokens; + const void* p_topk_ids; // [tokens, topk] + const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens + void* p_expert_mesh; // [expert, tokens] + index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens + // used for ws/LDS calculation index_t mesh_stride; // mesh_stride for p_expert_mesh mdiv topk_mdiv; }; @@ -1373,11 +1394,12 @@ struct MoeSortingMultiPhaseKernel_P0 CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) { Kargs k; - k.p_topk_ids = h.p_topk_ids; - k.p_expert_mesh = h.p_ws; - k.tokens = h.tokens; - k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens); - k.topk_mdiv = mdiv{static_cast(h.topk)}; + k.p_topk_ids = h.p_topk_ids; + k.p_local_tokens = h.p_local_tokens; + k.p_expert_mesh = h.p_ws; + k.tokens = h.tokens; + k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens); + k.topk_mdiv = mdiv{static_cast(h.topk)}; return k; } @@ -1394,7 +1416,26 @@ struct MoeSortingMultiPhaseKernel_P0 const topk_id_t* p_topk_ids = reinterpret_cast(kargs.p_topk_ids); MeshType* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh); - index_t total_elem = kargs.tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile; + index_t tokens = [&]() { + if constexpr(Problem::LocalToken) + { + return reinterpret_cast(kargs.p_local_tokens)[0]; + } + else + { + return kargs.tokens; + } + }(); + index_t rounded_tokens = [&]() { + if constexpr(Problem::LocalToken) + { + return (tokens + Problem::SubTokenTile - 1) / Problem::SubTokenTile * + Problem::SubTokenTile; + } + else + return tokens; + }(); + index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile; #pragma unroll Problem::SubTokenTile for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem; @@ -1405,8 +1446,15 @@ struct MoeSortingMultiPhaseKernel_P0 IndexType eid = x[j.value]; // ext_vector_type must use int to [] uint32_t curr_token_id, curr_topk_id; kargs.topk_mdiv.divmod(i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id); - p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = - (curr_topk_id + 1) & 0xffff; + if constexpr(Problem::LocalToken) + { + if(static_cast(curr_token_id) < tokens) + p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = + (curr_topk_id + 1) & 0xffff; + } + else + p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = + (curr_topk_id + 1) & 0xffff; }); } } @@ -1542,6 +1590,7 @@ struct MoeSortingMultiPhaseKernel_P01 { const void* p_topk_ids; // [tokens, topk] const void* p_local_expert_mask; // [expert] + const void* p_local_tokens; // [1] void* p_expert_mesh; // [expert, tokens] void* p_expert_cumsum; // [expert + 1] void* p_expert_sem; // [1] @@ -1569,6 +1618,7 @@ struct MoeSortingMultiPhaseKernel_P01 Kargs k; k.p_topk_ids = h.p_topk_ids; k.p_local_expert_mask = h.p_local_expert_mask; + k.p_local_tokens = h.p_local_tokens; k.p_expert_mesh = h.p_ws; k.p_expert_cumsum = reinterpret_cast( reinterpret_cast(h.p_ws) + @@ -1580,8 +1630,17 @@ struct MoeSortingMultiPhaseKernel_P01 k.tokens = h.tokens; k.num_experts = h.num_experts; k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens); - k.wg_count = WGCounts(h); - k.topk_mdiv = mdiv{static_cast(h.topk)}; + k.wg_count = [&]() { + if constexpr(Problem::LocalToken) + { + return GridSize(h); + } + else + { + return WGCounts(h); + } + }(); + k.topk_mdiv = mdiv{static_cast(h.topk)}; return k; } @@ -1607,13 +1666,46 @@ struct MoeSortingMultiPhaseKernel_P01 CK_TILE_DEVICE void operator()(Kargs kargs) const { workgroup_barrier wb{reinterpret_cast(kargs.p_expert_sem)}; + index_t tokens = [&]() { + if constexpr(Problem::LocalToken) + { + return reinterpret_cast(kargs.p_local_tokens)[0]; + } + else + { + return kargs.tokens; + } + }(); + index_t rounded_tokens = [&]() { + if constexpr(Problem::LocalToken) + { + return (tokens + Problem::SubTokenTile - 1) / Problem::SubTokenTile * + Problem::SubTokenTile; + } + else + return tokens; + }(); + index_t wg_count = [&]() { + if constexpr(Problem::LocalToken) + { + index_t total_elem = rounded_tokens * kargs.topk / Problem::SubTokenTile; + index_t elem_cnt = (total_elem + BLOCK_SIZE - 1) / BLOCK_SIZE; + + // no more than grid_size + return min(elem_cnt, kargs.wg_count); + } + else + { + return kargs.wg_count; + } + }(); { using topk_id_t = ext_vector_t; const topk_id_t* p_topk_ids = reinterpret_cast(kargs.p_topk_ids); IndexType* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh); - index_t total_elem = kargs.tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile; + index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile; #pragma unroll Problem::SubTokenTile for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem; @@ -1625,10 +1717,19 @@ struct MoeSortingMultiPhaseKernel_P01 uint32_t curr_token_id, curr_topk_id; kargs.topk_mdiv.divmod( i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id); - p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = curr_topk_id + 1; + // p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = curr_topk_id + 1; + if constexpr(Problem::LocalToken) + { + if(static_cast(curr_token_id) < tokens) + p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = + (curr_topk_id + 1) & 0xffff; + } + else + p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = + (curr_topk_id + 1) & 0xffff; }); } - if(static_cast(blockIdx.x) < kargs.wg_count) + if(static_cast(blockIdx.x) < wg_count) { wb.inc(); } @@ -1642,7 +1743,7 @@ struct MoeSortingMultiPhaseKernel_P01 if(eid >= kargs.num_experts) return; - wb.wait_lt(kargs.wg_count); + wb.wait_lt(wg_count); for(; eid < kargs.num_experts; eid += gridDim.x) { @@ -1731,6 +1832,7 @@ struct MoeSortingMultiPhaseKernel_P2 struct Kargs { const void* p_local_expert_mask; // [expert] + const void* p_local_tokens; // [1] void* p_expert_mesh; // [expert, tokens] void* p_expert_cumsum; // [expert + 1] void* p_total_tokens_post_pad; // [1] @@ -1747,6 +1849,7 @@ struct MoeSortingMultiPhaseKernel_P2 { Kargs k; k.p_local_expert_mask = h.p_local_expert_mask; + k.p_local_tokens = h.p_local_tokens; k.p_expert_cumsum = reinterpret_cast( reinterpret_cast(h.p_ws) + impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk)); @@ -1942,6 +2045,7 @@ struct MoeSortingMultiPhaseKernel_P3 { const void* p_weights; const void* p_local_expert_mask; + const void* p_local_tokens; void* p_sorted_token_ids; void* p_sorted_weights; void* p_expert_mesh; // [token, expert] @@ -1958,6 +2062,7 @@ struct MoeSortingMultiPhaseKernel_P3 Kargs k; k.p_weights = h.p_weights; k.p_local_expert_mask = h.p_local_expert_mask; + k.p_local_tokens = h.p_local_tokens; k.p_sorted_token_ids = h.p_sorted_token_ids; k.p_sorted_weights = h.p_sorted_weights; k.p_expert_mesh = h.p_ws; @@ -1994,6 +2099,16 @@ struct MoeSortingMultiPhaseKernel_P3 const WeightType* p_weights = static_cast(kargs.p_weights); WeightType* p_sorted_weights = reinterpret_cast(kargs.p_sorted_weights); + index_t tokens = [&]() { + if constexpr(Problem::LocalToken) + { + return reinterpret_cast(kargs.p_local_tokens)[0]; + } + else + { + return kargs.tokens; + } + }(); int eid = blockIdx.x; int wave_id = threadIdx.x / WarpSize; int lane_id = threadIdx.x % WarpSize; @@ -2019,7 +2134,7 @@ struct MoeSortingMultiPhaseKernel_P3 { int i_token = i * BLOCK_SIZE + threadIdx.x; IndexType x = 0; - if(i_token < kargs.tokens) + if(i_token < tokens) { x = p_expert_mesh[eid * kargs.mesh_stride + i_token]; } @@ -2066,7 +2181,7 @@ struct MoeSortingMultiPhaseKernel_P3 for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += BLOCK_SIZE) { #if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID - p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(kargs.tokens, kargs.topk_mdiv.divisor); + p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(tokens, kargs.topk_mdiv.divisor); #else p_sorted_token_ids[i] = tokens; #endif @@ -2105,6 +2220,7 @@ struct MoeSortingMultiPhaseKernel_P23 { const void* p_weights; const void* p_local_expert_mask; // [expert] + const void* p_local_tokens; // [1] void* p_expert_mesh; // [expert, tokens] void* p_expert_cumsum; // [expert + 1] void* p_total_tokens_post_pad; // [1] @@ -2127,6 +2243,7 @@ struct MoeSortingMultiPhaseKernel_P23 Kargs k; k.p_weights = h.p_weights; k.p_local_expert_mask = h.p_local_expert_mask; + k.p_local_tokens = h.p_local_tokens; k.p_expert_mesh = h.p_ws; k.p_expert_cumsum = reinterpret_cast( reinterpret_cast(h.p_ws) + @@ -2346,6 +2463,17 @@ struct MoeSortingMultiPhaseKernel_P23 return; // skip empty expert } + index_t tokens = [&]() { + if constexpr(Problem::LocalToken) + { + return reinterpret_cast(kargs.p_local_tokens)[0]; + } + else + { + return kargs.tokens; + } + }(); + // cumsum one by one constexpr index_t index_pack = Problem::SubTokenTile; // always packed using r_t = ext_vector_t; // always use int32x4 @@ -2357,7 +2485,7 @@ struct MoeSortingMultiPhaseKernel_P23 { int i_token_pack = i * BLOCK_SIZE + threadIdx.x; r_t x_v = 0; - if(i_token_pack < (kargs.tokens + index_pack - 1) / index_pack) + if(i_token_pack < (tokens + index_pack - 1) / index_pack) { x_v = reinterpret_cast(p_expert_mesh + eid * kargs.mesh_stride)[i_token_pack]; @@ -2554,7 +2682,7 @@ struct MoeSortingMultiPhaseKernel_P23 for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += BLOCK_SIZE) { #if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID - p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(kargs.tokens, kargs.topk_mdiv.divisor); + p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(tokens, kargs.topk_mdiv.divisor); #else p_sorted_token_ids[i] = tokens; #endif diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp index 39bc6ca93e..181266d7af 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp @@ -31,6 +31,7 @@ template struct MoeSortingProblemEx @@ -44,6 +45,7 @@ struct MoeSortingProblemEx static constexpr index_t SubTokenTile = SubTokenTile_; static constexpr bool SubTokenOneShot = SubTokenOneShot_; static constexpr bool LocalExpertMasking = LocalExpertMasking_; + static constexpr bool LocalToken = LocalToken_; static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_; static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 || SubTokenTile == 8); static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out @@ -54,6 +56,7 @@ template struct MoeSortingProblemMp { @@ -64,6 +67,7 @@ struct MoeSortingProblemMp static constexpr index_t SubTokenTile = SubTokenTile_; static constexpr bool LocalExpertMasking = LocalExpertMasking_; + static constexpr bool LocalToken = LocalToken_; static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_; static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 || SubTokenTile == 8 || SubTokenTile == 16);