diff --git a/example/ck_tile/13_moe_sorting/moe_sorting.cpp b/example/ck_tile/13_moe_sorting/moe_sorting.cpp index c4faa35e33..f00d948f25 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting.cpp @@ -152,6 +152,13 @@ bool test_moe_sorting(ck_tile::ArgParser args) if(local_expert_masking) local_expert_masking_dev.ToDevice(local_expert_masking_host.data()); + // if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr + ck_tile::index_t workspace_size = moe_sorting_get_workspace_size(tokens, num_experts); + ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0); + + if(workspace_size != 0) + moe_sorting_ws.SetZero(); // note, clear here!!!! + moe_sorting_trait trait{index_prec, weight_prec, local_expert_masking}; moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(), @@ -163,6 +170,7 @@ bool test_moe_sorting(ck_tile::ArgParser args) sorted_expert_ids_dev.GetDeviceBuffer(), sorted_id_cnt_dev.GetDeviceBuffer(), moe_buf_size > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr, + workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr, tokens, unit_size, num_experts, @@ -174,13 +182,68 @@ bool test_moe_sorting(ck_tile::ArgParser args) /* log_level = */ (kname ? 1 : 0), warmup, repeat}; + auto ms = moe_sorting(trait, karg, sc); - printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ", + // auto ms = moe_sorting_mp(trait, karg, sc); + +#if 0 + { + ck_tile::HostTensor ws_host({workspace_size}, {1}); + moe_sorting_ws.FromDevice(ws_host.data()); + + int * p_mesh = reinterpret_cast(ws_host.data()); + ck_tile::index_t row_size = ck_tile::impl::moe_sorting_mp_mesh_stride(tokens); + + std::cout << "topk_ids:" << std::endl; + + int * p_topk_ids = reinterpret_cast(topk_ids_host.data()); + for(int i_token = 0; i_token < tokens; i_token++) { + printf("[t:%2d]", i_token); + for(int i_topk = 0; i_topk < topk; i_topk++) { + printf("%d, ",p_topk_ids[i_token * topk + i_topk] ); + } + printf("\n"); + } + printf("----------------\n"); + + std::vector l_cumsum (num_experts + 1, 0); + for(int i_expert = 0; i_expert < num_experts; i_expert++ ) { + printf("[e:%2d]", i_expert); + int e_cnt = 0; + for(int i_token = 0; i_token < tokens; i_token++) { + auto v_mesh = p_mesh[i_expert * row_size + i_token]; + e_cnt += v_mesh != 0 ? 1 : 0; + printf("%d, ", v_mesh); + } + int e_cnt_unit = (e_cnt + unit_size - 1) / unit_size; + printf("[%d/%d]", e_cnt, e_cnt_unit); + printf("\n"); + l_cumsum[i_expert + 1] = l_cumsum[i_expert] + e_cnt_unit; + } + + printf("----------------\n"); + printf("cumsum:\n"); + for(int i_cc= 0; i_cc < num_experts + 1; i_cc++) { + printf("%2d, ", l_cumsum[i_cc]); + } + printf("\n"); + printf("----------------\n"); + + int * p_cumsum = p_mesh + ck_tile::impl::moe_sorting_mp_mesh_elem(tokens, num_experts); + for(int i_expert = 0; i_expert < num_experts + 1; i_expert++ ) { + printf("%2d(%d), ",p_cumsum[i_expert], p_cumsum[i_expert] / unit_size); + } + printf("\n"); + } +#endif + + printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, mp:%d, ", index_prec.c_str(), weight_prec.c_str(), tokens, num_experts, - topk); + topk, + workspace_size != 0 ? 1 : 0); if(local_expert_masking) { @@ -224,28 +287,41 @@ bool test_moe_sorting(ck_tile::ArgParser args) num_experts, unit_size, local_expert_masking); - rtn &= ck_tile::check_err( - sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6); - rtn &= ck_tile::check_err(sorted_weights_host, - sorted_weights_ref, - std::string("OUT Error: Incorrect w!"), - 1e-6, - 1e-6); - rtn &= ck_tile::check_err(sorted_expert_ids_host, - sorted_expert_ids_ref, - std::string("OUT Error: Incorrect eid!"), - 1e-6, - 1e-6); + printf("total_tokens_post_pad:%d(%d), ", + ref_total_tokens_post_pad, + sorted_id_cnt_host.mData[0]); + if(ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0]) + { + size_t slen = ref_total_tokens_post_pad; + rtn &= ck_tile::check_err(sorted_ids_host.slice({0}, {slen}), + sorted_ids_ref.slice({0}, {slen}), + std::string("OUT Error: Incorrect ids!"), + 1e-6, + 1e-6); + rtn &= ck_tile::check_err(sorted_weights_host.slice({0}, {slen}), + sorted_weights_ref.slice({0}, {slen}), + std::string("OUT Error: Incorrect w!"), + 1e-6, + 1e-6); + rtn &= ck_tile::check_err(sorted_expert_ids_host.slice({0}, {slen / unit_size}), + sorted_expert_ids_ref.slice({0}, {slen / unit_size}), + std::string("OUT Error: Incorrect eid!"), + 1e-6, + 1e-6); + } + else + { + printf("(token size not equal!!)"); + rtn = false; + } + if(moe_buf_size) { ck_tile::HostTensor moe_buf_ref({moe_buf_size}); rtn &= ck_tile::check_err( moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0); } - rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0]; - printf("total_tokens_post_pad:%d(%d), ", - ref_total_tokens_post_pad, - sorted_id_cnt_host.mData[0]); + // rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0]; } printf("valid:%s", rtn ? "y" : "n"); diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp index abff24a669..109ec1b157 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp @@ -153,18 +153,106 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi } } #else - using index_t = ck_tile::index_t; - using ms_weight_type = float; - auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts); - auto sub_token_ = r_ - 2; - r_ = (r_ - 2) / 8; - bool is_sub_token_onshot = a.tokens <= sub_token_; + if(moe_sorting_get_workspace_size(a.tokens, a.num_experts) != 0) + { + return moe_sorting_mp(t, a, s); + } + using index_t = ck_tile::index_t; + using ms_weight_type = float; + auto sub_token_ = ck_tile::moe_sorting_get_sub_token(a.tokens, a.num_experts); + auto row_ = sub_token_ / 8; + bool is_sub_token_onshot = a.tokens <= sub_token_; bool is_local_expert_masking = t.local_expert_masking; - (void)c_; - MOE_SORTING_DISPATCH_EMASK_(r_); + MOE_SORTING_DISPATCH_EMASK_(row_); // MOE_SORTING_DISPATCH_ETILE(0, 0); #endif } return -1; } + +#define MOE_SORTING_MP_0(unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = \ + ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() + +#define MOE_SORTING_MP_1(unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = \ + ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() + +#define MOE_SORTING_MP_2(unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = \ + ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() + +#define MOE_SORTING_MP_3(unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = \ + ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() + +float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) +{ + if(t.weight_type == "fp32" && t.index_type == "int32") + { + using ms_index_t = ck_tile::index_t; + using ms_weight_type = float; + + if(t.local_expert_masking) + { + float ave_time = ck_tile::launch_kernel(s, + MOE_SORTING_MP_0(1, true), + MOE_SORTING_MP_1(1, true), + MOE_SORTING_MP_2(1, true), + MOE_SORTING_MP_3(1, true)); + return ave_time; + } + else + { + float ave_time = ck_tile::launch_kernel(s, + MOE_SORTING_MP_0(1, false), + MOE_SORTING_MP_1(1, false), + MOE_SORTING_MP_2(1, false), + MOE_SORTING_MP_3(1, false)); + return ave_time; + } + } + return -1; +} + +int moe_sorting_get_workspace_size(int tokens, int num_experts) +{ + return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts); +} diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp index 5bda4d368a..b47ae9013b 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp @@ -18,4 +18,10 @@ struct moe_sorting_args : public ck_tile::MoeSortingHostArgs { }; +// use below API before call moe_sorting() to indicate if need workspace or not +// if return non zero, means need workspace, you need to allocate a GPU buffer +// and set to moe_sorting_args.p_ws +// NOTE: workspace size are required to clear zero before use the API +int moe_sorting_get_workspace_size(int tokens, int num_experts); float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s); +float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s); diff --git a/example/ck_tile/15_fused_moe/fused_moe.hpp b/example/ck_tile/15_fused_moe/fused_moe.hpp index 1f2246fa4a..b354d1d347 100644 --- a/example/ck_tile/15_fused_moe/fused_moe.hpp +++ b/example/ck_tile/15_fused_moe/fused_moe.hpp @@ -17,6 +17,9 @@ struct fused_moe_args const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input const void* local_expert_mask_ptr; // [e], local_expert_mask_ptr for EP void* o_ptr; // [m, k], output token (no need to do zeroing) + void* ws_ptr; // size is moe_sorting_get_workspace_size() + // if return zero, then could be nullptr + // must be cleard before use const void* topk_ids_ptr; // [tokens, topk] const void* topk_weight_ptr; // [tokens, topk] diff --git a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp index cf9ff2edba..466420f066 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp @@ -27,6 +27,7 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids; a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad; a.o_ptr, // void* p_moe_buf; + a.ws_ptr, // void* p_ws; a.num_tokens, // index_t tokens; a.block_m, // index_t unit_size; a.num_experts, // index_t num_experts; diff --git a/example/ck_tile/15_fused_moe/main.cpp b/example/ck_tile/15_fused_moe/main.cpp index 95adcd684b..cb93ce8907 100644 --- a/example/ck_tile/15_fused_moe/main.cpp +++ b/example/ck_tile/15_fused_moe/main.cpp @@ -371,6 +371,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem num_sorted_tiles_buf( num_sorted_tiles_host.get_element_space_size_in_bytes()); + // if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr + ck_tile::index_t workspace_size = ck_tile::moe_sorting_get_workspace_size(tokens, experts); + ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0); + if(workspace_size != 0) + moe_sorting_ws.SetZero(); // note, clear here!!!! + fused_moe_traits traits{prec_i, prec_w, prec_o, @@ -394,6 +400,7 @@ bool run(const ck_tile::ArgParser& arg_parser) local_expert_masking ? local_expert_mask_buf.GetDeviceBuffer() : nullptr, o_buf.GetDeviceBuffer(), + workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr, topk_ids_buf.GetDeviceBuffer(), topk_weight_buf.GetDeviceBuffer(), sorted_token_ids_buf.GetDeviceBuffer(), diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 340f6cb9e5..a1410d1f4f 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -101,7 +101,7 @@ namespace ck_tile { // max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1) -CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int num_tokens_, int num_experts_) +CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int tokens_, int num_experts_) { /* num_experts + 1 * +--------------------------------------+ @@ -132,7 +132,7 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int num_tokens_, int nu // round to sub_unroll multipl int r_for_sub_token = r - cumsum_bufs; - r_for_sub_token = min(r_for_sub_token, num_tokens_); + r_for_sub_token = min(r_for_sub_token, tokens_); r_for_sub_token = (r_for_sub_token + sub_unroll - 1) / sub_unroll * sub_unroll; r_for_sub_token = max(r_for_sub_token, 1); @@ -148,7 +148,6 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int num_tokens_, int nu mask_ = mask_ > 0b111 ? 0b111 : mask_; //clamp to 8x at most mask_ = ~mask_; - //printf("r_unroll_:%d, clz:%d, mask:%x\n", r_unroll_, clz_, mask_); fflush(stdout); r_for_sub_token = (r_unroll_ & mask_) * sub_unroll; } @@ -161,11 +160,17 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int num_tokens_, int nu return r_for_sub_token + cumsum_bufs; }(); - // printf("r:%d, c:%d\n", smem_rows, smem_cols); - return ck_tile::make_tuple(smem_rows, smem_cols); } +CK_TILE_HOST index_t moe_sorting_get_sub_token(int tokens_, int num_experts_) +{ + auto [r_, c_] = moe_sorting_get_smem_row_col(tokens_, num_experts_); + auto sub_token_ = r_ - 2; + (void) c_; + return sub_token_; +} + struct MoeSortingHostArgs { const void* p_topk_ids; // [token, topk] @@ -180,6 +185,9 @@ struct MoeSortingHostArgs // we fused the setzero of output of fused-moe buffer // set this pointer to nullptr will skip this operation void* p_moe_buf; + void* p_ws; // size is moe_sorting_get_workspace_size() + // if return zero, then could be nullptr + // must be cleard before use index_t tokens; index_t unit_size; // this is the M_a of fused-moe kernel index_t num_experts; @@ -1046,6 +1054,812 @@ struct MoeSortingKernel } }; +namespace impl { + +// [expert, padded_tokens] +CK_TILE_HOST_DEVICE index_t moe_sorting_mp_mesh_stride(index_t tokens) +{ + constexpr index_t chunk = 32; + return (tokens + chunk - 1) / chunk * chunk; +}; + +CK_TILE_HOST_DEVICE index_t moe_sorting_mp_mesh_elem(index_t tokens, index_t num_experts) +{ + index_t row_size = moe_sorting_mp_mesh_stride(tokens); + return num_experts * row_size; +}; + +CK_TILE_HOST_DEVICE index_t moe_sorting_mp_cumsum_elem(index_t num_experts) +{ + constexpr index_t chunk = 32; + index_t row_size = num_experts + 1; + return (row_size + chunk - 1) / chunk * chunk; +}; + +template +CK_TILE_DEVICE constexpr T moe_sorting_wave_reduce(T local, F reduce_f, number = {}) +{ + // constexpr int wave_size = 64; + // constexpr int reduce_stage = 6; // 1<<6=64 + // clang-format off + constexpr int reduce_stage = [](){ + if constexpr(wave_size_ == 2) return 1; + else if constexpr(wave_size_ == 4) return 2; + else if constexpr(wave_size_ == 8) return 3; + else if constexpr(wave_size_ == 16) return 4; + else if constexpr(wave_size_ == 32) return 5; + else if constexpr(wave_size_ == 64) return 6; + else return 0; + }(); + // clang-format on + T v_local = local; +#pragma unroll reduce_stage + for(int i_stage = 0; i_stage < reduce_stage; i_stage++) + { + int src_lane = __lane_id() ^ (1 << i_stage); + int32_t v_remote_tmp = + __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast(v_local)); + T v_remote = bit_cast(v_remote_tmp); + v_local = reduce_f(v_local, v_remote); + } + return v_local; +} + +// [a, b, c, d....] -> [a, a+b, a+b+c, a+b+c+d, ....] +// NOTE: wave_size need at least be 16!! dpp 16 is one row +template +CK_TILE_DEVICE void moe_sorting_wave_cumsum(data_t& thread_data) +{ + // wave_size must be power of 2 + constexpr int row_mask = 0xf; + constexpr int bank_mask = 0xf; + constexpr bool bound_ctrl = true; // ! out-of-bound is zero ! + auto reduce_op = [&](auto x_, auto y_) { return x_ + y_; }; + + if constexpr(wave_size > 1) + { + thread_data = reduce_op( + thread_data, + __builtin_bit_cast(data_t, + __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data), + 0x111, + row_mask, + bank_mask, + bound_ctrl))); // row_shr:1 + } + + if constexpr(wave_size > 2) + { + thread_data = reduce_op( + thread_data, + __builtin_bit_cast(data_t, + __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data), + 0x112, + row_mask, + bank_mask, + bound_ctrl))); // row_shr:2 + } + if constexpr(wave_size > 4) + { + thread_data = reduce_op( + thread_data, + __builtin_bit_cast(data_t, + __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data), + 0x114, + row_mask, + bank_mask, + bound_ctrl))); // row_shr:4 + } + if constexpr(wave_size == 8) + { + + // wave-size=8 need one extra shift + thread_data = reduce_op( + thread_data, + __builtin_bit_cast(data_t, + __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data), + 0x118, + row_mask, + bank_mask, + bound_ctrl))); // row_shr:8 +#if 0 + constexpr int bank_mask_0_7 = 0b1100; + auto reduce_op_r = [&](auto x_, auto y_) { return x_ - y_; }; + thread_data = reduce_op_r(thread_data, __builtin_bit_cast(data_t, + __builtin_amdgcn_update_dpp(0, /* old value */ + __builtin_bit_cast(int, thread_data), + 0x157, + row_mask, + bank_mask_0_7, + bound_ctrl))// row_newbcast:7 + ); +#else + data_t xxx = + __builtin_bit_cast(data_t, + __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data), + 0x157, + row_mask, + bank_mask, + bound_ctrl)); // row_newbcast:7 + + data_t yyy = (__lane_id() / 8) % 2 == 0 ? 0 : xxx; + thread_data = thread_data - yyy; +#endif + } + if constexpr(wave_size > 8) + { + thread_data = reduce_op( + thread_data, + __builtin_bit_cast(data_t, + __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data), + 0x118, + row_mask, + bank_mask, + bound_ctrl))); // row_shr:8 + } + + if constexpr(wave_size > 16) + { + // now row-0, row-0+row-1, row-1+row-2, row-2+row-3 + int v_remote_tmp = __builtin_amdgcn_ds_bpermute(((__lane_id() & 0x30) - 1) << 2, + __builtin_bit_cast(int, thread_data)); + v_remote_tmp = __lane_id() >= 16 ? v_remote_tmp : 0; + thread_data = reduce_op(thread_data, __builtin_bit_cast(data_t, v_remote_tmp)); + } + + if constexpr(wave_size > 32) + { + // lane-id 48...63->31 + int v_remote_tmp = __builtin_amdgcn_ds_bpermute(((__lane_id() & 0x30) - 17) << 2, + __builtin_bit_cast(int, thread_data)); + v_remote_tmp = __lane_id() >= 32 ? v_remote_tmp : 0; + thread_data = reduce_op(thread_data, __builtin_bit_cast(data_t, v_remote_tmp)); + } +} + +template +CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, index_t buf_bytes, index_t gid) +{ + // const index_t offset = (blockIdx.x - 1) * BLOCK_SIZE + threadIdx.x; + index_t offset = gid * BLOCK_SIZE + threadIdx.x; + if(offset < buf_bytes / 16) + { + buf[offset] = uint8x16_t{0}; + } +} + +} // namespace impl + +// prefer to run mp kernel if is not oneshot +CK_TILE_HOST bool moe_sorting_is_oneshot(int tokens_, int num_experts_) +{ + auto sub_token_ = moe_sorting_get_sub_token(tokens_, num_experts_); + bool is_sub_token_onshot = tokens_ <= sub_token_; + return is_sub_token_onshot; +} + +// return size in byte +CK_TILE_HOST index_t moe_sorting_mp_get_workspace_size(int tokens_, int num_experts_) +{ + index_t elem = impl::moe_sorting_mp_mesh_elem(tokens_, num_experts_) + + impl::moe_sorting_mp_cumsum_elem(num_experts_); + return elem * sizeof(index_t); +} + +// return size in byte +CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, int num_experts_) +{ +#if 1 + if(moe_sorting_is_oneshot(tokens_, num_experts_)) + { + return 0; + } + else + { + return moe_sorting_mp_get_workspace_size(tokens_, num_experts_); + } +#else + return moe_sorting_mp_get_workspace_size(tokens_, num_experts_); +#endif +} + +// below kernel is multi-phase implementation for large token and/or expert case + +// write into a buffer to record the token cnt +// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5 +// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]] +// tok-0 tok-1 tok-2 tok-3 tok-4 +// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float +// number) +// +// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]] +// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5 +// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]] +/* + +p_expert_mesh: + t0 t1 t2 t3 t4 r5 + +--+--+--+--+--+--+ +e0 | 1| | | | | | +e1 | | | 1| 1| 1| | +e2 | | 1| | 1| | | +e3 | 1| 1| 1| 1| 1| | +e4 | | | | | | | +e5 | 1| 1| 1| | | 1| + + +p_expert_cumsum: + | 1| 3| 2| 5| 0| 4| + e0 e1 e2 e3 e4 e5 + +p_expert_cumsum(with M_a pad, and skip zero tokens): + | 4| 4| 4| 8| 0| 4| + e0 e1 e2 e3 e4 e5 + +p_expert_cumsum + | 0| 4| 8|12|20|20|24| + +local_expert_mask : [1, 0, 1, 1, 0, 1] (mask out expert-id=1, 4) + +p_m_cumsum + | 0| 1| 1| 2| 3| 3| 4| + +*/ + +// count topk_id into mesh +template +struct MoeSortingMultiPhaseKernel_P0 +{ + using Problem = remove_cvref_t; + + using IndexType = typename Problem::IndexType; + using WeightType = typename Problem::WeightType; + + static constexpr index_t BLOCK_SIZE = 256; + static constexpr index_t OCCUPANCY = 2; // hard coded + + typedef MoeSortingHostArgs MoeSortingKargs; + + using Hargs = MoeSortingHostArgs; + + struct Kargs + { + const void* p_topk_ids; // [tokens, topk] + void* p_expert_mesh; // [expert, tokens] + index_t tokens; + index_t mesh_stride; // mesh_stride for p_expert_mesh + mdiv topk_mdiv; + }; + + CK_TILE_HOST static constexpr auto get_num_cu() + { + index_t num_cu = [&]() { + hipDeviceProp_t dev_prop; + hipDevice_t dev; + HIP_CHECK_ERROR(hipGetDevice(&dev)); + HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev)); + return dev_prop.multiProcessorCount; + }(); + return num_cu; + } + + CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) + { + Kargs k; + k.p_topk_ids = h.p_topk_ids; + k.p_expert_mesh = h.p_ws; + k.tokens = h.tokens; + k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens); + k.topk_mdiv = mdiv{static_cast(h.topk)}; + return k; + } + + CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; } + + CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } + + // in byte + CK_TILE_HOST static constexpr auto GetSmemSize() { return 0; } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + using topk_id_t = ext_vector_t; + + static_assert(Problem::SubTokenTile == 1 || Problem::SubTokenTile == 2 || + Problem::SubTokenTile == 4); + + const topk_id_t* p_topk_ids = reinterpret_cast(kargs.p_topk_ids); + IndexType* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh); + index_t total_elem = kargs.tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile; + +#pragma unroll Problem::SubTokenTile + for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem; i += blockDim.x) + { + auto x = p_topk_ids[i]; + static_for<0, Problem::SubTokenTile, 1>{}([&](auto j) { + IndexType eid = x[j.value]; // ext_vector_type must use int to [] + uint32_t curr_token_id, curr_topk_id; + kargs.topk_mdiv.divmod(i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id); + p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = curr_topk_id + 1; + }); + } + } +}; + +// cnt total tokens for a expert +template +struct MoeSortingMultiPhaseKernel_P1 +{ + using Problem = remove_cvref_t; + + using IndexType = typename Problem::IndexType; + using WeightType = typename Problem::WeightType; + + static constexpr index_t BLOCK_SIZE = 256; + static constexpr index_t OCCUPANCY = 2; // hard coded + + typedef MoeSortingHostArgs MoeSortingKargs; + + using Hargs = MoeSortingHostArgs; + struct Kargs + { + const void* p_local_expert_mask; // [expert] + void* p_expert_mesh; // [expert, tokens] + void* p_expert_cumsum; + index_t mesh_stride; // mesh_stride for p_expert_mesh + }; + + CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) + { + Kargs k; + k.p_local_expert_mask = h.p_local_expert_mask; + k.p_expert_mesh = h.p_ws; + k.p_expert_cumsum = + reinterpret_cast(reinterpret_cast(h.p_ws) + + impl::moe_sorting_mp_mesh_elem(h.tokens, h.num_experts)); + k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens); + + return k; + } + + CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) { return dim3(h.num_experts); } + + CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } + + // in byte + CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() + { + return BLOCK_SIZE / warpSize * sizeof(IndexType); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + __shared__ char smem[GetSmemSize()]; + + int eid = blockIdx.x; + + constexpr index_t index_pack = 4; // always packed + using r_t = ext_vector_t; // always use int32x4 + r_t* p_expert_mesh = reinterpret_cast( + reinterpret_cast(kargs.p_expert_mesh) + eid * kargs.mesh_stride); + + static_assert(Problem::SubTokenTile == 1 || Problem::SubTokenTile == 2 || + Problem::SubTokenTile == 4); + const IndexType* p_local_expert_mask = + static_cast(kargs.p_local_expert_mask); + IndexType* p_expert_cumsum = reinterpret_cast(kargs.p_expert_cumsum); + + auto f_sum = [](auto x_, auto y_) { return x_ + y_; }; + + int loops = (kargs.mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE; + + if constexpr(Problem::LocalExpertMasking) + { + IndexType mask = p_local_expert_mask[eid]; + if(mask == 0) + return; // skip + } + + index_t cnt = 0; // per-wave cnt + for(int i = 0; i < loops; i++) + { + int position = i * BLOCK_SIZE + threadIdx.x; + r_t v{0}; + if(position < (kargs.mesh_stride / index_pack)) + v = p_expert_mesh[position]; + index_t local_sum = 0; + static_for<0, index_pack, 1>{}( + [&](auto i_vec) { local_sum += v[i_vec.value] != 0 ? 1 : 0; }); + cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum); + } + + index_t lane_id = threadIdx.x % warpSize; + index_t wave_id = threadIdx.x / warpSize; + + // reduce cross wave + IndexType* s = reinterpret_cast(smem); + if(lane_id == 0) + { + s[wave_id] = cnt; + } + __syncthreads(); + + if(threadIdx.x == 0) + { + index_t c = 0; + for(auto i = 0; i < (BLOCK_SIZE / warpSize); i++) + { + c += s[i]; + } + p_expert_cumsum[eid] = c; + } + } +}; + +// token count cumsum +template +struct MoeSortingMultiPhaseKernel_P2 +{ + using Problem = remove_cvref_t; + + using IndexType = typename Problem::IndexType; + using WeightType = typename Problem::WeightType; + + static constexpr index_t BLOCK_SIZE = 256; + static constexpr index_t OCCUPANCY = 2; // hard coded + + typedef MoeSortingHostArgs MoeSortingKargs; + + using Hargs = MoeSortingHostArgs; + struct Kargs + { + const void* p_local_expert_mask; // [expert] + void* p_expert_mesh; // [expert, tokens] + void* p_expert_cumsum; // [expert + 1] + void* p_total_tokens_post_pad; // [1] + void* p_sorted_expert_ids; + void* p_moe_buf; + index_t tokens; + index_t num_experts; + index_t mesh_stride; // mesh_stride for p_expert_mesh + mdiv unit_size_mdiv; + index_t moe_buf_bytes; + }; + + CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) + { + Kargs k; + k.p_local_expert_mask = h.p_local_expert_mask; + // k.p_expert_mesh = h.p_ws; + k.p_expert_cumsum = + reinterpret_cast(reinterpret_cast(h.p_ws) + + impl::moe_sorting_mp_mesh_elem(h.tokens, h.num_experts)); + k.p_total_tokens_post_pad = h.p_total_tokens_post_pad; + k.p_sorted_expert_ids = h.p_sorted_expert_ids; + + k.p_moe_buf = h.p_moe_buf; + + k.tokens = h.tokens; + k.num_experts = h.num_experts; + k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens); + k.unit_size_mdiv = mdiv{static_cast(h.unit_size)}; + + k.moe_buf_bytes = h.moe_buf_bytes; + + return k; + } + + CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) + { + // use 1 block to cumsum + return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16)); + } + + CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } + + // in byte + CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() + { + return 2 * BLOCK_SIZE * sizeof(IndexType); + } + + // reduce single pixel within a wave + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + if(blockIdx.x > 0) + { + impl::moe_buf_set_zero_kernel( + reinterpret_cast(kargs.p_moe_buf), + kargs.moe_buf_bytes, + blockIdx.x - 1); + return; + } + __shared__ char smem[GetSmemSize()]; + IndexType* s = reinterpret_cast(smem); + + const IndexType* p_local_expert_mask = + static_cast(kargs.p_local_expert_mask); + IndexType* p_expert_cumsum = reinterpret_cast(kargs.p_expert_cumsum); + IndexType* p_total_tokens_post_pad = + reinterpret_cast(kargs.p_total_tokens_post_pad); + IndexType* p_sorted_expert_ids = reinterpret_cast(kargs.p_sorted_expert_ids); + + const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE; + index_t wave_id = threadIdx.x / warpSize; + index_t lane_id = threadIdx.x % warpSize; + + IndexType prev_cumsum_a = 0; + IndexType prev_cumsum_b = 0; + + for(index_t i = 0; i < loops; i++) + { + index_t position = i * BLOCK_SIZE + threadIdx.x; + IndexType a_ = 0; // token count for a expert + IndexType b_ = 0; // mask for a expert + if(position < kargs.num_experts) + { + a_ = p_expert_cumsum[position]; + if constexpr(Problem::LocalExpertMasking) + b_ = p_local_expert_mask[position]; + } + + int blocks_pers_expert = + kargs.unit_size_mdiv.div(a_ + kargs.unit_size_mdiv.divisor - 1); + // pad token + int padded_blocks_per_expert = [&]() { + int x_ = [&]() { + if constexpr(Problem::SkipExpertsWithZeroTokens) + { + // if local_cnt is zero, blocks_pers_expert will be zero + // this is what we want to achieve + return blocks_pers_expert; // * kargs.unit_size_mdiv.divisor; + } + else + { + return max(blocks_pers_expert, 1); + } + }(); + if constexpr(Problem::LocalExpertMasking) + { + return b_ ? x_ : 0; + } + else + return x_; + }(); + + IndexType cumsum_a = padded_blocks_per_expert; + IndexType cumsum_b = b_; + + // Note: we first cumsum local round, then add previous cumsum + impl::moe_sorting_wave_cumsum(cumsum_a); + impl::moe_sorting_wave_cumsum(cumsum_b); + + __syncthreads(); + if(lane_id == warpSize - 1) + { + s[4 + wave_id] = cumsum_a; + s[4 + wave_id + BLOCK_SIZE / warpSize] = cumsum_b; + } + + __syncthreads(); + + // reduce cross wave + static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) { + IndexType prev_a = s[4 + i_w]; + IndexType prev_b = s[4 + i_w + BLOCK_SIZE / warpSize]; + prev_a = wave_id > i_w ? prev_a : 0; // mask out + prev_b = wave_id > i_w ? prev_b : 0; // mask out + cumsum_a += prev_a; + cumsum_b += prev_b; + }); + + // Now let's add previous cumsum + cumsum_a += prev_cumsum_a; + cumsum_b += prev_cumsum_b; + + if(threadIdx.x == BLOCK_SIZE - 1) + { + s[2] = cumsum_a; // store the last cumsum + s[3] = cumsum_b; + } + + IndexType out_0 = cumsum_a - padded_blocks_per_expert; // exclusive cumsum tok cnt + IndexType out_1 = cumsum_b - b_; // exclusive cumsum mask cnt + + __syncthreads(); + prev_cumsum_a = s[2]; + prev_cumsum_b = s[3]; + + if(position < kargs.num_experts) + { + p_expert_cumsum[position] = out_0 * kargs.unit_size_mdiv.divisor; + } + + { + if constexpr(Problem::LocalExpertMasking) + { + if(b_) + { + for(int j = 0; j < blocks_pers_expert; j++) + { + p_sorted_expert_ids[out_0 + j] = out_1; + } + } + } + else + { + for(int j = 0; j < blocks_pers_expert; j++) + { + p_sorted_expert_ids[out_0 + j] = position; + } + } + } + } + + if(threadIdx.x == 0) + { + auto total_tokens_post_pad = prev_cumsum_a * kargs.unit_size_mdiv.divisor; + p_total_tokens_post_pad[0] = total_tokens_post_pad; + p_expert_cumsum[kargs.num_experts] = total_tokens_post_pad; + } + } +}; + +template +struct MoeSortingMultiPhaseKernel_P3 +{ + using Problem = remove_cvref_t; + + using IndexType = typename Problem::IndexType; + using WeightType = typename Problem::WeightType; + + static constexpr index_t BLOCK_SIZE = 256; + static constexpr index_t OCCUPANCY = 2; // hard coded + + typedef MoeSortingHostArgs MoeSortingKargs; + + using Hargs = MoeSortingHostArgs; + + struct Kargs + { + const void* p_weights; + const void* p_local_expert_mask; + void* p_sorted_token_ids; + void* p_sorted_weights; + void* p_expert_mesh; // [token, expert] + void* p_expert_cumsum; + + index_t tokens; + index_t num_experts; + index_t mesh_stride; // mesh_stride for p_expert_mesh + mdiv topk_mdiv; + }; + + CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) + { + Kargs k; + k.p_weights = h.p_weights; + k.p_local_expert_mask = h.p_local_expert_mask; + k.p_sorted_token_ids = h.p_sorted_token_ids; + k.p_sorted_weights = h.p_sorted_weights; + k.p_expert_mesh = h.p_ws; + k.p_expert_cumsum = + reinterpret_cast(reinterpret_cast(h.p_ws) + + impl::moe_sorting_mp_mesh_elem(h.tokens, h.num_experts)); + k.tokens = h.tokens; + k.num_experts = h.num_experts; + k.topk_mdiv = mdiv{static_cast(h.topk)}; + k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens); + return k; + } + + CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) { return dim3(h.num_experts); } + + CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } + + // in byte + CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() + { + return (4 + BLOCK_SIZE / warpSize) * sizeof(IndexType); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + __shared__ char smem[GetSmemSize()]; + + const IndexType* p_local_expert_mask = + static_cast(kargs.p_local_expert_mask); + IndexType* s = reinterpret_cast(smem); + IndexType* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh); + IndexType* p_sorted_token_ids = reinterpret_cast(kargs.p_sorted_token_ids); + IndexType* p_expert_cumsum = reinterpret_cast(kargs.p_expert_cumsum); + const WeightType* p_weights = static_cast(kargs.p_weights); + WeightType* p_sorted_weights = reinterpret_cast(kargs.p_sorted_weights); + + static_assert(Problem::SubTokenTile == 1 || Problem::SubTokenTile == 2 || + Problem::SubTokenTile == 4); + + int eid = blockIdx.x; + int wave_id = threadIdx.x / warpSize; + int lane_id = threadIdx.x % warpSize; + int e_start = p_expert_cumsum[eid]; + int e_end = p_expert_cumsum[eid + 1]; + if constexpr(Problem::SkipExpertsWithZeroTokens) + { + if(e_start == e_end) + return; + } + + if constexpr(Problem::LocalExpertMasking) + { + int e_mask = p_local_expert_mask[eid]; + if(e_mask == 0) + return; // skip empty expert + } + + // cumsum one by one + int loops = (kargs.mesh_stride + BLOCK_SIZE - 1) / BLOCK_SIZE; + int prev_cumsum = 0; + for(int i = 0; i < loops; i++) + { + int i_token = i * BLOCK_SIZE + threadIdx.x; + IndexType x = 0; + if(i_token < kargs.tokens) + { + x = p_expert_mesh[eid * kargs.mesh_stride + i_token]; + } + int i_topk = x - 1; // topk of this token + int i_show = x != 0 ? 1 : 0; // has this token or not + int cumsum = i_show; + impl::moe_sorting_wave_cumsum(cumsum); + + __syncthreads(); + if(lane_id == warpSize - 1) + { + s[4 + wave_id] = cumsum; + } + __syncthreads(); + + // reduce cross wave + static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) { + IndexType prev = s[4 + i_w]; + prev = wave_id > i_w ? prev : 0; // mask out + cumsum += prev; + }); + cumsum += prev_cumsum; // add previous round cumsum + if(threadIdx.x == BLOCK_SIZE - 1) + { + s[0] = cumsum; + } + __syncthreads(); + + int position = cumsum - i_show; + prev_cumsum = s[0]; // update the last cumsum + + if(i_show) + { +#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID + p_sorted_token_ids[e_start + position] = MOE_SORTING_MOCK_ID(i_token, i_topk); +#else + p_sorted_token_ids[e_start + position] = i_token; +#endif + p_sorted_weights[e_start + position] = + p_weights[i_token * kargs.topk_mdiv.divisor + i_topk]; + } + } + + for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += BLOCK_SIZE) + { +#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID + p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(kargs.tokens, kargs.topk_mdiv.divisor); +#else + p_sorted_token_ids[i] = tokens; +#endif + p_sorted_weights[i] = static_cast(0.0); + } + } +}; + #undef MOE_SORTING_MOCK_ID } // namespace ck_tile diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp index 15effe7118..a98e0d7652 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp @@ -49,4 +49,21 @@ struct MoeSortingProblemEx static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out }; +template +struct MoeSortingProblemMp +{ + // TODO: this kernel only support warp per row + using WeightType = remove_cvref_t; + using IndexType = remove_cvref_t; + + static constexpr index_t SubTokenTile = SubTokenTile_; + static constexpr bool LocalExpertMasking = LocalExpertMasking_; + static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_; + static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4); +}; + } // namespace ck_tile