diff --git a/example/ck_tile/13_moe_sorting/moe_sorting.cpp b/example/ck_tile/13_moe_sorting/moe_sorting.cpp index f139081cd4..16fe0ef150 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting.cpp @@ -35,7 +35,20 @@ auto create_args(int argc, char* argv[]) .insert("e", "8", "number of num_experts") .insert("k", "4", "topk") .insert("unit", "32", "unit_size") +#if MOE_SORTING_FMOE_2D_BUF + .insert("moe_buf_interm_dim", "0", "interm_dim(col) of the following fmoe buf") + .insert( + "moe_buf_elem_bytes", "2", "fmoe buf element byte size, 1:8bit, 2:16bit, 4:32bit...") +#else .insert("moe_buf_size", "0", "moe_buf_size") +#endif + .insert("ci", + "1", + "clear workspace inside API or not(if \"0\", require manually clear outside)") + .insert( + "dispatch", + "0", + "dispatch policy. 0:automatically pick up kernel, 1:use single kernel, 2:use mp kernel") .insert("local_eid", "-1", "a list of experts enabled as local expert. e.g. \"0,1,4,5\"\n" @@ -88,10 +101,17 @@ bool test_moe_sorting(ck_tile::ArgParser args) int topk = args.get_int("k"); int seed = args.get_int("seed"); int unit_size = args.get_int("unit"); - int64_t moe_buf_size = static_cast(args.get_uint64("moe_buf_size")); - int kname = args.get_int("kname"); - int warmup = args.get_int("warmup"); - int repeat = args.get_int("repeat"); +#if MOE_SORTING_FMOE_2D_BUF + int moe_buf_interm_dim = args.get_int("moe_buf_interm_dim"); + int moe_buf_elem_bytes = args.get_int("moe_buf_elem_bytes"); +#else + int64_t moe_buf_size = static_cast(args.get_uint64("moe_buf_size")); +#endif + int kname = args.get_int("kname"); + int warmup = args.get_int("warmup"); + int repeat = args.get_int("repeat"); + bool clear_inside = args.get_int("ci") != 0; + int dispatch_policy = args.get_int("dispatch"); int max_output_ids = ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size); @@ -149,11 +169,26 @@ bool test_moe_sorting(ck_tile::ArgParser args) ck_tile::HostTensor sorted_ids_host({max_output_ids}, {1}); ck_tile::HostTensor sorted_weights_host({max_output_ids}, {1}); ck_tile::HostTensor sorted_expert_ids_host({max_output_ids / unit_size}, {1}); - ck_tile::HostTensor sorted_id_cnt_host({1}, {1}); + // for simplicity, below buffer allocate 2 dword + ck_tile::HostTensor sorted_id_cnt_host({2}, {1}); +#if MOE_SORTING_FMOE_2D_BUF + ck_tile::HostTensor moe_buf_host( + {static_cast(is_local_token ? local_tokens : tokens) * moe_buf_interm_dim * + moe_buf_elem_bytes}); + auto moe_buf_bytes = moe_buf_interm_dim == 0 ? static_cast(0) + : moe_buf_host.get_element_space_size_in_bytes(); +#else ck_tile::HostTensor moe_buf_host({moe_buf_size}); + auto moe_buf_bytes = moe_buf_size == 0 ? static_cast(0) + : moe_buf_host.get_element_space_size_in_bytes(); +#endif ck_tile::FillUniformDistribution{-.5f, .5f}(weights_host); +#if MOE_SORTING_FMOE_2D_BUF + ck_tile::FillUniformDistribution{-.5f, .5f}(moe_buf_host); +#else ck_tile::FillUniformDistribution{-.5f, .5f}(moe_buf_host); +#endif topid_unique_gen(topk_ids_host.mData, tokens, topk, num_experts, seed); ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes()); @@ -176,7 +211,7 @@ bool test_moe_sorting(ck_tile::ArgParser args) topk_ids_dev.ToDevice(topk_ids_host.data()); weights_dev.ToDevice(weights_host.data()); - if(moe_buf_size > 0) + if(moe_buf_bytes > 0) { moe_buf_dev.ToDevice(moe_buf_host.data()); } @@ -184,29 +219,31 @@ bool test_moe_sorting(ck_tile::ArgParser args) local_expert_masking_dev.ToDevice(local_expert_masking_host.data()); // if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr - ck_tile::index_t workspace_size = moe_sorting_get_workspace_size(tokens, num_experts, topk); + ck_tile::index_t workspace_size = + moe_sorting_get_workspace_size(tokens, num_experts, topk, dispatch_policy); ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0); - if(workspace_size != 0) + if(workspace_size != 0 && clear_inside == false) moe_sorting_ws.SetZero(); // note, clear here!!!! - moe_sorting_trait trait{index_prec, weight_prec, local_expert_masking}; + moe_sorting_trait trait{ + index_prec, weight_prec, local_expert_masking, clear_inside, dispatch_policy}; - moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(), - weights_dev.GetDeviceBuffer(), - local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer() - : nullptr, - is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr, - sorted_ids_dev.GetDeviceBuffer(), - sorted_weights_dev.GetDeviceBuffer(), - sorted_expert_ids_dev.GetDeviceBuffer(), - sorted_id_cnt_dev.GetDeviceBuffer(), - moe_buf_size > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr, - workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr, - tokens, - unit_size, - num_experts, - topk, - static_cast(moe_buf_size * sizeof(float))}; + moe_sorting_args karg + { + topk_ids_dev.GetDeviceBuffer(), weights_dev.GetDeviceBuffer(), + local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer() : nullptr, + is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr, + sorted_ids_dev.GetDeviceBuffer(), sorted_weights_dev.GetDeviceBuffer(), + sorted_expert_ids_dev.GetDeviceBuffer(), sorted_id_cnt_dev.GetDeviceBuffer(), + moe_buf_bytes > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr, + workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr, tokens, unit_size, + num_experts, topk, +#if MOE_SORTING_FMOE_2D_BUF + moe_buf_interm_dim, moe_buf_elem_bytes +#else + static_cast(moe_buf_size * sizeof(float)) +#endif + }; ck_tile::stream_config sc{nullptr, true, @@ -219,7 +256,7 @@ bool test_moe_sorting(ck_tile::ArgParser args) #if 0 { - ck_tile::HostTensor ws_host({workspace_size}, {1}); + ck_tile::HostTensor ws_host({workspace_size}, {1}); moe_sorting_ws.FromDevice(ws_host.data()); int * p_mesh = reinterpret_cast(ws_host.data()); @@ -268,7 +305,12 @@ bool test_moe_sorting(ck_tile::ArgParser args) } #endif - printf("[%s|%s]tokens:%d", index_prec.c_str(), weight_prec.c_str(), tokens); + printf("[%s|%s|%s|%d]tokens:%d", + index_prec.c_str(), + weight_prec.c_str(), + workspace_size == 0 ? "cx" : (clear_inside ? "ci" : "co"), + dispatch_policy, + tokens); if(is_local_token) { printf("(%d)", local_tokens); @@ -280,6 +322,19 @@ bool test_moe_sorting(ck_tile::ArgParser args) printf("local_eid:%s, ", args.get_str("local_eid").c_str()); } + if(moe_buf_bytes > 0) + { +#if MOE_SORTING_FMOE_2D_BUF + printf("moe_buf:%lu(%d,%d), ", + static_cast(moe_buf_bytes), + moe_buf_interm_dim, + moe_buf_elem_bytes); +#else + + printf("moe_buf:%lu, ", static_cast(moe_buf_bytes)); +#endif + } + if(ms < 0) printf("not supported\n"); else @@ -294,7 +349,7 @@ bool test_moe_sorting(ck_tile::ArgParser args) sorted_weights_dev.FromDevice(sorted_weights_host.data()); sorted_expert_ids_dev.FromDevice(sorted_expert_ids_host.data()); sorted_id_cnt_dev.FromDevice(sorted_id_cnt_host.data()); - if(moe_buf_size > 0) + if(moe_buf_bytes > 0) { moe_buf_dev.FromDevice(moe_buf_host.data()); } @@ -340,6 +395,16 @@ bool test_moe_sorting(ck_tile::ArgParser args) std::string("OUT Error: Incorrect eid!"), 1e-6, 1e-6); + // if(is_local_token) + { + auto t_ = is_local_token ? local_tokens : tokens; + bool _f = t_ == sorted_id_cnt_host.mData[1]; + rtn &= _f; + if(!_f) + { + printf("not equal token buffer pad %d(%d)\n", t_, sorted_id_cnt_host.mData[1]); + } + } } else { @@ -347,9 +412,13 @@ bool test_moe_sorting(ck_tile::ArgParser args) rtn = false; } - if(moe_buf_size) + if(moe_buf_bytes) { +#if MOE_SORTING_FMOE_2D_BUF + ck_tile::HostTensor moe_buf_ref({moe_buf_bytes}); +#else ck_tile::HostTensor moe_buf_ref({moe_buf_size}); +#endif rtn &= ck_tile::check_err( moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0); } diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp index 0899fefcfc..037891353e 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp @@ -175,7 +175,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi } } #else - if(moe_sorting_get_workspace_size(a.tokens, a.num_experts, a.topk) != 0) + if(moe_sorting_get_workspace_size(a.tokens, a.num_experts, a.topk, t.dispatch_policy) != 0) { return moe_sorting_mp(t, a, s); } @@ -293,6 +293,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi { \ float ave_time = \ ck_tile::launch_kernel(s, \ + maybe_clear_workspace, \ MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \ @@ -302,6 +303,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi { \ float ave_time = \ ck_tile::launch_kernel(s, \ + maybe_clear_workspace, \ MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \ @@ -314,6 +316,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi { \ float ave_time = \ ck_tile::launch_kernel(s, \ + maybe_clear_workspace, \ MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \ @@ -323,6 +326,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi { \ float ave_time = ck_tile::launch_kernel( \ s, \ + maybe_clear_workspace, \ MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \ @@ -330,6 +334,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi } \ } +#define MOR_SORTING_CLEAR_WS_DISPATCH_(is_local_token_, block_size_, occu_) \ + [&]() { \ + using problem_ = \ + ck_tile::MoeSortingClearWorkspaceProblem; \ + using kernel = ck_tile::MoeSortingClearWorkspaceKernel; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() + float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) { bool is_local_token = a.p_local_tokens != nullptr; @@ -338,6 +353,22 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co using ms_index_t = ck_tile::index_t; using ms_weight_type = float; + auto maybe_clear_workspace = [=](const ck_tile::stream_config& s_) { + if(t.clear_workspace_inside_api) + { + if(is_local_token) + { + auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(true, 1024, 1); + k(s_); + } + else + { + auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(false, 1024, 1); + k(s_); + } + } + }; + if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > ck_tile::get_smem_capacity()) { @@ -345,6 +376,7 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co if(t.local_expert_masking) { float ave_time = ck_tile::launch_kernel(s, + maybe_clear_workspace, MOE_SORTING_MP_0(ms_index_t, 1, true), MOE_SORTING_MP_1(ms_index_t, 1, true), MOE_SORTING_MP_2(ms_index_t, 1, true), @@ -354,6 +386,7 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co else { float ave_time = ck_tile::launch_kernel(s, + maybe_clear_workspace, MOE_SORTING_MP_0(ms_index_t, 1, false), MOE_SORTING_MP_1(ms_index_t, 1, false), MOE_SORTING_MP_2(ms_index_t, 1, false), @@ -405,7 +438,7 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co return -1; } -int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk) +int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk, int dispatch_policy) { - return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk); + return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk, dispatch_policy); } diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp index 0fe8d81e70..6c6cd0f4fa 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp @@ -10,8 +10,14 @@ struct moe_sorting_trait { std::string index_type; - std::string weight_type; // currently always float - bool local_expert_masking; // if mask experts as local expert + std::string weight_type; // currently always float + bool local_expert_masking; // if mask experts as local expert + bool clear_workspace_inside_api; // if true, no need clear workspace outsize (will take care of + // it inside API) + int dispatch_policy; // 0 - let the API choose kernel for you. 1 - always use single kerenl. 2 - + // always use mp kernel NOTE: moe_sorting_get_workspace_size() need use + // same dispatch_policy value. it will be undefined behavior if ppl using + // different value when get ws and call the kernel }; struct moe_sorting_args : public ck_tile::MoeSortingHostArgs @@ -22,6 +28,6 @@ struct moe_sorting_args : public ck_tile::MoeSortingHostArgs // if return non zero, means need workspace, you need to allocate a GPU buffer // and set to moe_sorting_args.p_ws // NOTE: workspace size are required to clear zero before use the API -int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk); +int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk, int dispatch_policy); float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s); float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s); diff --git a/example/ck_tile/13_moe_sorting/script/smoke_test.sh b/example/ck_tile/13_moe_sorting/script/smoke_test.sh index 63bc0acceb..2c245f6e7f 100644 --- a/example/ck_tile/13_moe_sorting/script/smoke_test.sh +++ b/example/ck_tile/13_moe_sorting/script/smoke_test.sh @@ -1,7 +1,9 @@ # #!/bin/sh EXE=./build/bin/tile_example_moe_sorting +MOE_BUF="12" +if [ "x$MOE_BUF" = "x1" ] ; then $EXE -t=80 -e=17 -moe_buf_size=16 $EXE -t=111 -e=117 -moe_buf_size=4 $EXE -t=1000 -e=55 -moe_buf_size=1024 @@ -42,3 +44,46 @@ $EXE -t=23 -local_t=9 -e=1 -k=1 $EXE -t=7 -local_t=0 -e=89 -k=1 -local_eid=0,8,12,33 $EXE -t=61 -local_t=0 -e=333 -k=99 -local_eid=0,8,12,33 $EXE -t=133940 -local_t=111921 -e=256 -k=17 -moe_buf_size=133940 +else +$EXE -t=80 -e=17 -moe_buf_interm_dim=16 -moe_buf_elem_bytes=4 +$EXE -t=111 -e=117 -moe_buf_interm_dim=4 -moe_buf_elem_bytes=4 +$EXE -t=1000 -e=55 -moe_buf_interm_dim=1024 -moe_buf_elem_bytes=1 +$EXE -t=99 -e=120 -moe_buf_interm_dim=10244 -moe_buf_elem_bytes=2 +$EXE -t=175 -e=64 -k=8 +$EXE -t=65 -e=8 -k=2 +$EXE -t=1 -e=25 +$EXE -t=31 -e=19 -k=15 +$EXE -t=81 -e=37 -k=7 +$EXE -t=23 -e=1 -k=1 +$EXE -t=127 -e=99 -k=19 +$EXE -t=71 -e=11 -k=11 +$EXE -t=1 -e=1 -k=1 +$EXE -t=99 -e=2 -k=1 +$EXE -t=333 -e=99 -k=13 +$EXE -t=11 -e=256 -k=5 +$EXE -t=64 -e=455 -k=8 +$EXE -t=777 -e=802 -k=99 +$EXE -t=4097 -e=906 -k=51 +$EXE -t=128 -e=32 -k=5 -local_t=6 -moe_buf_interm_dim=262144 +$EXE -t=13 -e=64 -k=3 -local_eid=4,5,6,7,8,9,10,11 +$EXE -t=99 -e=33 -k=9 -local_eid=6,10,11,15,19 +$EXE -t=80 -e=99 -k=10 -local_eid=0,8,12,33 +$EXE -t=11 -e=256 -k=5 -local_eid=99,110,129 +$EXE -t=128 -e=128 -k=6 -moe_buf_interm_dim=163840 -moe_buf_elem_bytes=1 +$EXE -t=8192 -e=32 -k=5 -local_t=11 -moe_buf_interm_dim=163840 +$EXE -t=8192 -e=32 -k=8 -local_t=12 -moe_buf_interm_dim=163840 -moe_buf_elem_bytes=1 +$EXE -t=8192 -e=256 -k=5 -local_t=13 -moe_buf_interm_dim=163840 +$EXE -t=8192 -e=256 -k=8 -local_t=8 -moe_buf_interm_dim=163840 +$EXE -t=163840 -e=256 -k=8 -local_t=4 -moe_buf_interm_dim=163840 -moe_buf_elem_bytes=4 +$EXE -t=12 -local_t=3 -e=256 -k=5 -local_eid=9,10,199,145 +$EXE -t=67 -local_t=9 -e=555 -k=5 -local_eid=19,23,24,25,26,99 +$EXE -t=99 -local_t=93 -e=121 -local_t=4 -moe_buf_interm_dim=10244 +$EXE -t=536 -local_t=345 -e=802 -k=99 +$EXE -t=331 -local_t=39 -e=83 -k=33 +$EXE -t=765 -local_t=654 -e=783 -k=8 +$EXE -t=23 -local_t=9 -e=1 -k=1 +$EXE -t=7 -local_t=0 -e=89 -k=1 -local_eid=0,8,12,33 +$EXE -t=61 -local_t=0 -e=333 -k=99 -local_eid=0,8,12,33 +$EXE -t=133940 -local_t=111921 -e=256 -k=17 -local_t=2 -moe_buf_interm_dim=133940 -moe_buf_elem_bytes=1 + +fi diff --git a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp index 27274878a2..78f664a671 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp @@ -6,7 +6,8 @@ int fused_moe_get_workspace_size(int tokens, int num_experts, int topk) { - return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk); + return ck_tile::moe_sorting_get_workspace_size( + tokens, num_experts, topk, 0 /*dispatch policy*/); } float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_config& s) @@ -24,23 +25,28 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf }(); auto t0 = fused_moesorting_trait{"int32", "fp32", t.local_expert_masking}; - auto a0 = fused_moesorting_args{ - a.topk_ids_ptr, // const void* p_topk_ids; - a.topk_weight_ptr, // const void* p_weights; - a.local_expert_mask_ptr, // const void* p_local_expert_mask; - a.local_tokens, - a.sorted_token_ids_ptr, // void* p_sorted_token_ids; - a.sorted_weight_ptr, // void* p_sorted_weights; - a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids; - a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad; - a.o_ptr, // void* p_moe_buf; - a.ws_ptr, // void* p_ws; - a.num_tokens, // index_t tokens; - a.block_m, // index_t unit_size; - a.num_experts, // index_t num_experts; - a.topk, // index_t topk; - static_cast(a.num_tokens) * a.stride_token * - o_data_bytes // index_t moe_buf_bytes; + auto a0 = fused_moesorting_args + { + a.topk_ids_ptr, // const void* p_topk_ids; + a.topk_weight_ptr, // const void* p_weights; + a.local_expert_mask_ptr, // const void* p_local_expert_mask; + a.local_tokens, + a.sorted_token_ids_ptr, // void* p_sorted_token_ids; + a.sorted_weight_ptr, // void* p_sorted_weights; + a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids; + a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad; + a.o_ptr, // void* p_moe_buf; + a.ws_ptr, // void* p_ws; + a.num_tokens, // index_t tokens; + a.block_m, // index_t unit_size; + a.num_experts, // index_t num_experts; + a.topk, // index_t topk; +#if MOE_SORTING_FMOE_2D_BUF + a.stride_token, o_data_bytes, +#else + static_cast(a.num_tokens) * + a.stride_token* o_data_bytes // index_t moe_buf_bytes; +#endif }; auto t1 = fused_moegemm_traits{t.prec_i, diff --git a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp index f745284f3e..83454a3969 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp @@ -413,5 +413,6 @@ float fused_moesorting_mp(fused_moesorting_trait t, int fused_moesorting_get_workspace_size(int tokens, int num_experts, int topk) { - return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk); + return ck_tile::moe_sorting_get_workspace_size( + tokens, num_experts, topk, 0 /*dispatch policy*/); } diff --git a/example/ck_tile/15_fused_moe/main.cpp b/example/ck_tile/15_fused_moe/main.cpp index d9950426a2..35f24c1155 100644 --- a/example/ck_tile/15_fused_moe/main.cpp +++ b/example/ck_tile/15_fused_moe/main.cpp @@ -399,7 +399,7 @@ bool run(const ck_tile::ArgParser& arg_parser) // if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr ck_tile::index_t workspace_size = - ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk); + ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk, 0 /*dispatch_policy*/); ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0); if(workspace_size != 0) moe_sorting_ws.SetZero(); // note, clear here!!!! diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 5da675ae42..db85fae643 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -23,6 +23,11 @@ namespace ck_tile { #define MOE_SORTING_FUSE_MP_01 0 #endif +// weather use 2d buffer indexing for fmoe ws or 1d +#ifndef MOE_SORTING_FMOE_2D_BUF +#define MOE_SORTING_FMOE_2D_BUF 1 +#endif + // clang-format off // [indexing implementation-1] // using M_a as constexpr block_size to partition all tokens into different slices @@ -171,7 +176,7 @@ struct MoeSortingHostArgs void* p_sorted_token_ids; void* p_sorted_weights; void* p_sorted_expert_ids; - void* p_total_tokens_post_pad; + void* p_total_tokens_post_pad; // [2], [0]:outputed tokens_post_padded, [1]:actual tokens on current rank (local_tokens or tokens) // we fused the setzero of output of fused-moe buffer // set this pointer to nullptr will skip this operation void* p_moe_buf; @@ -182,7 +187,18 @@ struct MoeSortingHostArgs index_t unit_size; // this is the M_a of fused-moe kernel index_t num_experts; index_t topk; +#if MOE_SORTING_FMOE_2D_BUF + // NOTE: + // moe_buf_* is a 2d ws buffer used for the following fmoe kernel + // arranged as row*col, where row=tokens(or local_token), col=interm_dim + // we fuse this clearing inside sorting kernel + // Besides, we require inter_dim to be multiple of 16 byte(make sure when alloc ws for fmoe) + index_t moe_buf_interm_dim; // p_moe_buf interm_dim + index_t moe_buf_elem_bytes; // p_moe_buf byte size(8bit, 16bit, 32bit, etc.) +#else long_index_t moe_buf_bytes; // byte size of p_moe_buf +#endif + }; template @@ -197,6 +213,9 @@ struct MoeSortingKernel using Hargs = MoeSortingHostArgs; + static constexpr index_t BLOCK_SIZE = 256; + static constexpr index_t OCCUPANCY = 2; // hard coded + struct Kargs { const void* p_topk_ids; @@ -210,8 +229,12 @@ struct MoeSortingKernel void* p_moe_buf; index_t tokens; index_t num_experts; +#if MOE_SORTING_FMOE_2D_BUF + index_t moe_buf_interm_dim; // p_moe_buf interm_dim + index_t moe_buf_elem_bytes; // p_moe_buf byte size(8bit, 16bit, 32bit, etc.) +#else long_index_t moe_buf_bytes; - +#endif index_t tokens_per_thread; index_t smem_rows; mdiv unit_size_mdiv; @@ -220,10 +243,27 @@ struct MoeSortingKernel // mdiv sub_tokens_mdiv; }; + CK_TILE_HOST static constexpr auto get_num_cu() + { + index_t num_cu = [&]() { + hipDeviceProp_t dev_prop; + hipDevice_t dev; + HIP_CHECK_ERROR(hipGetDevice(&dev)); + HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev)); + return dev_prop.multiProcessorCount; + }(); + return num_cu; + } + CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) { +#if MOE_SORTING_FMOE_2D_BUF + (void)h; + return get_num_cu() * OCCUPANCY; +#else // TODO: assume num-experts not too much return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BlockSize(h).x * 16)); +#endif } CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h) @@ -263,7 +303,12 @@ struct MoeSortingKernel k.p_total_tokens_post_pad = h.p_total_tokens_post_pad; k.tokens = h.tokens; k.num_experts = h.num_experts; +#if MOE_SORTING_FMOE_2D_BUF + k.moe_buf_interm_dim = h.moe_buf_interm_dim; + k.moe_buf_elem_bytes = h.moe_buf_elem_bytes; +#else k.moe_buf_bytes = h.moe_buf_bytes; +#endif const auto blocks = BlockSize(h); // NOTE: tokens could from p_local_tokens, so here this variable is useless @@ -431,6 +476,24 @@ struct MoeSortingKernel } } + CK_TILE_DEVICE void + moe_buf_set_zero_kernel_2d(void* buf, index_t row, index_t col, index_t elem_bytes) const + { + const long_index_t total_pixels = static_cast(row) * col; + const long_index_t total_bytes = total_pixels * elem_bytes; + const long_index_t total_elems = total_bytes / 16; // always use dwordx4 + + using vector_type = ext_vector_t; + vector_type* p_buf = reinterpret_cast(buf); + auto zero_ = vector_type{0}; + + for(long_index_t i = (blockIdx.x - 1) * BLOCK_SIZE + threadIdx.x; i < total_elems; + i += (gridDim.x - 1) * BLOCK_SIZE) + { + p_buf[i] = zero_; + } + } + CK_TILE_DEVICE void moe_align_block_size_kernel(const IndexType* __restrict__ topk_id, const WeightType* __restrict__ weights, index_t* p_sorted_token_ids, @@ -863,7 +926,8 @@ struct MoeSortingKernel } if((lid + i_e_ - get_warp_size()) == (num_experts - 1)) { - *p_total_tokens_post_pad = local_cumsum_; + *p_total_tokens_post_pad = local_cumsum_; + p_total_tokens_post_pad[1] = tokens; } } __syncthreads(); @@ -1005,20 +1069,6 @@ struct MoeSortingKernel CK_TILE_DEVICE void operator()(Kargs kargs) const { - if(blockIdx.x > 0) - { - if(kargs.p_moe_buf) - { - moe_buf_set_zero_kernel(reinterpret_cast(kargs.p_moe_buf), - kargs.moe_buf_bytes); - } - return; - } - const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor; - extern __shared__ char smem[]; - -#if MOE_SORTING_USE_EX_KERNEL - (void)numel; index_t tokens_ = [&]() { if constexpr(Problem::LocalToken) { @@ -1029,6 +1079,25 @@ struct MoeSortingKernel return kargs.tokens; } }(); + + if(blockIdx.x > 0) + { + if(kargs.p_moe_buf) + { +#if MOE_SORTING_FMOE_2D_BUF + moe_buf_set_zero_kernel_2d( + kargs.p_moe_buf, tokens_, kargs.moe_buf_interm_dim, kargs.moe_buf_elem_bytes); +#else + moe_buf_set_zero_kernel(reinterpret_cast(kargs.p_moe_buf), + kargs.moe_buf_bytes); +#endif + } + return; + } + + extern __shared__ char smem[]; + +#if MOE_SORTING_USE_EX_KERNEL return moe_align_block_size_kernel_ex( static_cast(kargs.p_topk_ids), static_cast(kargs.p_weights), @@ -1045,6 +1114,7 @@ struct MoeSortingKernel kargs.smem_rows, smem); #else + const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor; return moe_align_block_size_kernel(static_cast(kargs.p_topk_ids), static_cast(kargs.p_weights), static_cast(kargs.p_sorted_token_ids), @@ -1066,6 +1136,8 @@ namespace impl { // [expert, padded_tokens] CK_TILE_HOST_DEVICE index_t moe_sorting_mp_mesh_stride(index_t tokens) { + // Pad to multiply of 32. This can make sure even if the mesh is in 8bit, + // we can still use dwordx4 load/store constexpr index_t chunk = 32; return (tokens + chunk - 1) / chunk * chunk; }; @@ -1261,6 +1333,24 @@ CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, long_index_t buf_by } } +template +CK_TILE_DEVICE void moe_buf_set_zero_kernel_2d( + void* buf, index_t row, index_t col, index_t elem_bytes, index_t gid, index_t blocks) +{ + const long_index_t total_pixels = static_cast(row) * col; + const long_index_t total_bytes = total_pixels * elem_bytes; + const long_index_t total_elems = total_bytes / 16; // always use dwordx4 + + using vector_type = ext_vector_t; + vector_type* p_buf = reinterpret_cast(buf); + auto zero_ = vector_type{0}; + + for(long_index_t i = gid * BLOCK_SIZE + threadIdx.x; i < total_elems; i += blocks * BLOCK_SIZE) + { + p_buf[i] = zero_; + } +} + } // namespace impl // TODO: tokens could be from @@ -1292,12 +1382,29 @@ CK_TILE_HOST index_t moe_sorting_mp_get_workspace_size(int tokens_, int num_expe } // return size in byte -CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, int num_experts_, int topk_) +// dispatch_policy: 0-automatically pick up kerel. 1-always use single kernel, 2-always use mp +// kernel +CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, + int num_experts_, + int topk_, + int dispatch_policy_) { #if 1 - if(moe_sorting_is_oneshot(tokens_, num_experts_)) + // return 0; + if(dispatch_policy_ == 0) { - return 0; + if(moe_sorting_is_oneshot(tokens_, num_experts_)) + { + return 0; + } + else + { + return moe_sorting_mp_get_workspace_size(tokens_, num_experts_, topk_); + } + } + else if(dispatch_policy_ == 1) + { + return 0; // always use single kernel } else { @@ -1308,6 +1415,98 @@ CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, int num_experts #endif } +template +struct MoeSortingClearWorkspaceKernel +{ + using Problem = remove_cvref_t; + static constexpr index_t BLOCK_SIZE = Problem::BlockSize; + static constexpr index_t OCCUPANCY = Problem::Occu; + + using Hargs = MoeSortingHostArgs; + + struct Kargs + { + const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens + void* p_expert_mesh; // [expert, tokens] + index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens + // used for ws/LDS calculation + index_t num_experts; + index_t mesh_stride; // mesh_stride for p_expert_mesh + index_t mesh_byte_size; + }; + + CK_TILE_HOST static constexpr auto get_num_cu() + { + index_t num_cu = [&]() { + hipDeviceProp_t dev_prop; + hipDevice_t dev; + HIP_CHECK_ERROR(hipGetDevice(&dev)); + HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev)); + return dev_prop.multiProcessorCount; + }(); + return num_cu; + } + + CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) + { + Kargs k; + k.p_local_tokens = h.p_local_tokens; + k.p_expert_mesh = h.p_ws; + k.tokens = h.tokens; + k.num_experts = h.num_experts; + k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens); + k.mesh_byte_size = impl::moe_sorting_mesh_byte_size(h.tokens, h.num_experts, h.topk); + return k; + } + + CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; } + + CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } + + // in byte + CK_TILE_HOST static constexpr auto GetSmemSize() { return 0; } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + index_t tokens = [&]() { + if constexpr(Problem::LocalToken) + { + return reinterpret_cast(kargs.p_local_tokens)[0]; + } + else + { + return kargs.tokens; + } + }(); + + index_t mesh_stride = [&]() { + if constexpr(Problem::LocalToken) + { + return impl::moe_sorting_mp_mesh_stride(tokens); + } + else + { + return kargs.mesh_stride; + } + }(); + + index_t row_size = mesh_stride; // impl::moe_sorting_mp_mesh_stride(tokens); + index_t pixels = kargs.num_experts * row_size; + index_t total_bytes = pixels * kargs.mesh_byte_size; + index_t total_elems = total_bytes / 16; // always use dwordx4 + + using vector_type = ext_vector_t; + vector_type* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh); + auto zero_ = vector_type{0}; + + for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elems; + i += gridDim.x * BLOCK_SIZE) + { + p_expert_mesh[i] = zero_; + } + } +}; + // below kernel is multi-phase implementation for large token and/or expert case // write into a buffer to record the token cnt @@ -1435,6 +1634,16 @@ struct MoeSortingMultiPhaseKernel_P0 else return tokens; }(); + index_t mesh_stride = [&]() { + if constexpr(Problem::LocalToken) + { + return impl::moe_sorting_mp_mesh_stride(tokens); + } + else + { + return kargs.mesh_stride; + } + }(); index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile; #pragma unroll Problem::SubTokenTile @@ -1449,12 +1658,11 @@ struct MoeSortingMultiPhaseKernel_P0 if constexpr(Problem::LocalToken) { if(static_cast(curr_token_id) < tokens) - p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = + p_expert_mesh[eid * mesh_stride + curr_token_id] = (curr_topk_id + 1) & 0xffff; } else - p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = - (curr_topk_id + 1) & 0xffff; + p_expert_mesh[eid * mesh_stride + curr_token_id] = (curr_topk_id + 1) & 0xffff; }); } } @@ -1479,6 +1687,7 @@ struct MoeSortingMultiPhaseKernel_P1 struct Kargs { const void* p_local_expert_mask; // [expert] + const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens void* p_expert_mesh; // [expert, tokens] void* p_expert_cumsum; index_t mesh_stride; // mesh_stride for p_expert_mesh @@ -1488,6 +1697,7 @@ struct MoeSortingMultiPhaseKernel_P1 { Kargs k; k.p_local_expert_mask = h.p_local_expert_mask; + k.p_local_tokens = h.p_local_tokens; k.p_expert_mesh = h.p_ws; k.p_expert_cumsum = reinterpret_cast( reinterpret_cast(h.p_ws) + @@ -1511,12 +1721,9 @@ struct MoeSortingMultiPhaseKernel_P1 { __shared__ char smem[GetSmemSize()]; - int eid = blockIdx.x; - + int eid = blockIdx.x; constexpr index_t index_pack = Problem::SubTokenTile; // always packed using r_t = ext_vector_t; // always use int32x4 - r_t* p_expert_mesh = reinterpret_cast( - reinterpret_cast(kargs.p_expert_mesh) + eid * kargs.mesh_stride); const IndexType* p_local_expert_mask = static_cast(kargs.p_local_expert_mask); @@ -1524,7 +1731,32 @@ struct MoeSortingMultiPhaseKernel_P1 auto f_sum = [](auto x_, auto y_) { return x_ + y_; }; - int loops = (kargs.mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE; + index_t tokens = [&]() { + if constexpr(Problem::LocalToken) + { + return reinterpret_cast(kargs.p_local_tokens)[0]; + } + else + { + return 0; // will not use if not LocalToken + } + }(); + + index_t mesh_stride = [&]() { + if constexpr(Problem::LocalToken) + { + return impl::moe_sorting_mp_mesh_stride(tokens); + } + else + { + return kargs.mesh_stride; + } + }(); + + r_t* p_expert_mesh = reinterpret_cast( + reinterpret_cast(kargs.p_expert_mesh) + eid * mesh_stride); + + int loops = (mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE; if constexpr(Problem::LocalExpertMasking) { @@ -1538,7 +1770,7 @@ struct MoeSortingMultiPhaseKernel_P1 { int position = i * BLOCK_SIZE + threadIdx.x; r_t v{0}; - if(position < (kargs.mesh_stride / index_pack)) + if(position < (mesh_stride / index_pack)) v = p_expert_mesh[position]; index_t local_sum = 0; static_for<0, index_pack, 1>{}( @@ -1835,7 +2067,7 @@ struct MoeSortingMultiPhaseKernel_P2 const void* p_local_tokens; // [1] void* p_expert_mesh; // [expert, tokens] void* p_expert_cumsum; // [expert + 1] - void* p_total_tokens_post_pad; // [1] + void* p_total_tokens_post_pad; // [2] void* p_sorted_expert_ids; void* p_moe_buf; index_t tokens; @@ -1863,15 +2095,36 @@ struct MoeSortingMultiPhaseKernel_P2 k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens); k.unit_size_mdiv = mdiv{static_cast(h.unit_size)}; +#if MOE_SORTING_FMOE_2D_BUF + k.moe_buf_interm_dim = h.moe_buf_interm_dim; + k.moe_buf_elem_bytes = h.moe_buf_elem_bytes; +#else k.moe_buf_bytes = h.moe_buf_bytes; +#endif return k; } + CK_TILE_HOST static constexpr auto get_num_cu() + { + index_t num_cu = [&]() { + hipDeviceProp_t dev_prop; + hipDevice_t dev; + HIP_CHECK_ERROR(hipGetDevice(&dev)); + HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev)); + return dev_prop.multiProcessorCount; + }(); + return num_cu; + } + CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) { +#if MOE_SORTING_FMOE_2D_BUF + return dim3(h.num_experts + get_num_cu() * OCCUPANCY); +#else // use 1 block to cumsum return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16)); +#endif } CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } @@ -1888,11 +2141,21 @@ struct MoeSortingMultiPhaseKernel_P2 { if(blockIdx.x > 0) { +#if MOE_SORTING_FMOE_2D_BUF + impl::moe_buf_set_zero_kernel_2d(kargs.p_moe_buf, + kargs.tokens, + kargs.moe_buf_interm_dim, + kargs.moe_buf_elem_bytes, + blockIdx.x - 1, + gridDim.x - 1); + return; +#else impl::moe_buf_set_zero_kernel( reinterpret_cast(kargs.p_moe_buf), kargs.moe_buf_bytes, blockIdx.x - 1); return; +#endif } __shared__ char smem[GetSmemSize()]; IndexType* s = reinterpret_cast(smem); @@ -2223,7 +2486,7 @@ struct MoeSortingMultiPhaseKernel_P23 const void* p_local_tokens; // [1] void* p_expert_mesh; // [expert, tokens] void* p_expert_cumsum; // [expert + 1] - void* p_total_tokens_post_pad; // [1] + void* p_total_tokens_post_pad; // [2] void* p_sorted_expert_ids; void* p_sorted_token_ids; @@ -2235,7 +2498,17 @@ struct MoeSortingMultiPhaseKernel_P23 index_t mesh_stride; // mesh_stride for p_expert_mesh mdiv unit_size_mdiv; mdiv topk_mdiv; - long_index_t moe_buf_bytes; +#if MOE_SORTING_FMOE_2D_BUF + // NOTE: + // moe_buf_* is a 2d ws buffer used for the following fmoe kernel + // arranged as row*col, where row=tokens(or local_token), col=interm_dim + // we fuse this clearing inside sorting kernel + // Besides, we require inter_dim to be multiple of 16 byte(make sure when alloc ws for fmoe) + index_t moe_buf_interm_dim; // p_moe_buf interm_dim + index_t moe_buf_elem_bytes; // p_moe_buf byte size(8bit, 16bit, 32bit, etc.) +#else + long_index_t moe_buf_bytes; // byte size of p_moe_buf +#endif }; CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) @@ -2262,16 +2535,37 @@ struct MoeSortingMultiPhaseKernel_P23 k.unit_size_mdiv = mdiv{static_cast(h.unit_size)}; k.topk_mdiv = mdiv{static_cast(h.topk)}; +#if MOE_SORTING_FMOE_2D_BUF + k.moe_buf_interm_dim = h.moe_buf_interm_dim; + k.moe_buf_elem_bytes = h.moe_buf_elem_bytes; +#else k.moe_buf_bytes = h.moe_buf_bytes; +#endif return k; } + CK_TILE_HOST static constexpr auto get_num_cu() + { + index_t num_cu = [&]() { + hipDeviceProp_t dev_prop; + hipDevice_t dev; + HIP_CHECK_ERROR(hipGetDevice(&dev)); + HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev)); + return dev_prop.multiProcessorCount; + }(); + return num_cu; + } + CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) { +#if MOE_SORTING_FMOE_2D_BUF + return dim3(h.num_experts + get_num_cu() * OCCUPANCY); +#else // use 1 block to cumsum // return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16)); return dim3(h.num_experts + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16)); +#endif } CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } @@ -2287,13 +2581,34 @@ struct MoeSortingMultiPhaseKernel_P23 // reduce single pixel within a wave CK_TILE_DEVICE void operator()(Kargs kargs) const { + index_t tokens = [&]() { + if constexpr(Problem::LocalToken) + { + return reinterpret_cast(kargs.p_local_tokens)[0]; + } + else + { + return kargs.tokens; + } + }(); + if(static_cast(blockIdx.x) >= kargs.num_experts) { +#if MOE_SORTING_FMOE_2D_BUF + impl::moe_buf_set_zero_kernel_2d(kargs.p_moe_buf, + tokens, + kargs.moe_buf_interm_dim, + kargs.moe_buf_elem_bytes, + blockIdx.x - kargs.num_experts, + gridDim.x - kargs.num_experts); + return; +#else impl::moe_buf_set_zero_kernel( reinterpret_cast(kargs.p_moe_buf), kargs.moe_buf_bytes, blockIdx.x - kargs.num_experts); return; +#endif } extern __shared__ char smem[]; @@ -2428,13 +2743,15 @@ struct MoeSortingMultiPhaseKernel_P23 { auto total_tokens_post_pad = prev_cumsum_a * kargs.unit_size_mdiv.divisor; if(blockIdx.x == 0) + { p_total_tokens_post_pad[0] = total_tokens_post_pad; + p_total_tokens_post_pad[1] = tokens; + } p_expert_cumsum_smem[kargs.num_experts] = total_tokens_post_pad; } } __syncthreads(); - { const IndexType* p_local_expert_mask = static_cast(kargs.p_local_expert_mask); @@ -2463,14 +2780,14 @@ struct MoeSortingMultiPhaseKernel_P23 return; // skip empty expert } - index_t tokens = [&]() { + index_t mesh_stride = [&]() { if constexpr(Problem::LocalToken) { - return reinterpret_cast(kargs.p_local_tokens)[0]; + return impl::moe_sorting_mp_mesh_stride(tokens); } else { - return kargs.tokens; + return kargs.mesh_stride; } }(); @@ -2478,7 +2795,8 @@ struct MoeSortingMultiPhaseKernel_P23 constexpr index_t index_pack = Problem::SubTokenTile; // always packed using r_t = ext_vector_t; // always use int32x4 using d_t = ext_vector_t; - int loops = (kargs.mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE; + int loops = (mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE; + int prev_cumsum = 0; for(int i = 0; i < loops; i++) @@ -2487,8 +2805,7 @@ struct MoeSortingMultiPhaseKernel_P23 r_t x_v = 0; if(i_token_pack < (tokens + index_pack - 1) / index_pack) { - x_v = reinterpret_cast(p_expert_mesh + - eid * kargs.mesh_stride)[i_token_pack]; + x_v = reinterpret_cast(p_expert_mesh + eid * mesh_stride)[i_token_pack]; } r_t x_r; diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp index 181266d7af..ea218b9c25 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp @@ -73,4 +73,12 @@ struct MoeSortingProblemMp SubTokenTile == 8 || SubTokenTile == 16); }; +template +struct MoeSortingClearWorkspaceProblem +{ + static constexpr bool LocalToken = LocalToken_; + static constexpr index_t BlockSize = BlockSize_; + static constexpr index_t Occu = Occu_; +}; + } // namespace ck_tile