From edccbb3694b86b930a2978b923472b37a95f2aa3 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Mon, 23 Dec 2024 10:59:02 +0800 Subject: [PATCH] [CK_TILE] optimize moe-sorting kernel (#1771) * opt moe sorting * remove commented code [ROCm/composable_kernel commit: 3d15f364b367b24ac709ea5687fa2d7d39f07cf9] --- .../13_moe_sorting/moe_sorting_api.cpp | 53 ++-- .../13_moe_sorting/script/smoke_test.sh | 3 +- .../instances/fused_moesorting_api.cpp | 53 ++-- .../fused_moe/kernel/moe_sorting_kernel.hpp | 253 +++++++++++++++--- .../pipeline/moe_sorting_problem.hpp | 13 +- 5 files changed, 292 insertions(+), 83 deletions(-) diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp index 25e99c5306..723fb3f69f 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp @@ -3,18 +3,42 @@ #include "moe_sorting_api.hpp" -#define MOE_SORTING_DISPATCH(unroll_num_) \ - constexpr ck_tile::index_t unroll_num = unroll_num_; \ - using ms_problem = ck_tile::MoeSortingProblem; \ - using kernel = ck_tile::MoeSortingKernel; \ - auto kargs = kernel::MakeKargs(a); \ - const dim3 grids = kernel::GridSize(a); \ - const dim3 blocks = kernel::BlockSize(a); \ - const auto lds_bytes = kernel::GetSmemSize(a); \ - float ave_time = ck_tile::launch_kernel( \ - s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ +#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr ck_tile::index_t expert_tile = expert_tile_; \ + using ms_problem = \ + ck_tile::MoeSortingProblem; \ + using kernel = ck_tile::MoeSortingKernel; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + const auto lds_bytes = kernel::GetSmemSize(a); \ + float ave_time = ck_tile::launch_kernel( \ + s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ return ave_time; +#define MOE_SORTING_DISPATCH(unroll_num_) \ + if(a.num_experts <= 8) \ + { \ + MOE_SORTING_DISPATCH_ETILE(unroll_num_, 8) \ + } \ + else if(a.num_experts <= 16) \ + { \ + MOE_SORTING_DISPATCH_ETILE(unroll_num_, 16) \ + } \ + else if(a.num_experts <= 32) \ + { \ + MOE_SORTING_DISPATCH_ETILE(unroll_num_, 32) \ + } \ + else if(a.num_experts <= 64) \ + { \ + MOE_SORTING_DISPATCH_ETILE(unroll_num_, 64) \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \ + } + float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) { if(t.weight_type == "fp32" && t.index_type == "int32") @@ -49,21 +73,12 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi case(6): { MOE_SORTING_DISPATCH(6); } - case(7): { - MOE_SORTING_DISPATCH(7); - } case(8): { MOE_SORTING_DISPATCH(8); } - case(9): { - MOE_SORTING_DISPATCH(9); - } case(10): { MOE_SORTING_DISPATCH(10); } - case(11): { - MOE_SORTING_DISPATCH(11); - } default: { MOE_SORTING_DISPATCH(4); } diff --git a/example/ck_tile/13_moe_sorting/script/smoke_test.sh b/example/ck_tile/13_moe_sorting/script/smoke_test.sh index 1fc5eafcb0..3ff8a7332d 100644 --- a/example/ck_tile/13_moe_sorting/script/smoke_test.sh +++ b/example/ck_tile/13_moe_sorting/script/smoke_test.sh @@ -16,4 +16,5 @@ $EXE -t=127 -e=99 -k=19 $EXE -t=71 -e=11 -k=11 $EXE -t=1 -e=1 -k=1 $EXE -t=99 -e=2 -k=1 -$EXE -t=333 -e=99 -k=13 \ No newline at end of file +$EXE -t=333 -e=99 -k=13 +$EXE -t=128 -e=32 -k=5 -moe_buf_size=262144 diff --git a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp index 75aaf86b74..7ca24c5c9a 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp @@ -3,18 +3,42 @@ #include "fused_moesorting.hpp" -#define MOE_SORTING_DISPATCH(unroll_num_) \ - constexpr ck_tile::index_t unroll_num = unroll_num_; \ - using ms_problem = ck_tile::MoeSortingProblem; \ - using kernel = ck_tile::MoeSortingKernel; \ - auto kargs = kernel::MakeKargs(a); \ - const dim3 grids = kernel::GridSize(a); \ - const dim3 blocks = kernel::BlockSize(a); \ - const auto lds_bytes = kernel::GetSmemSize(a); \ - float ave_time = ck_tile::launch_kernel( \ - s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ +#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr ck_tile::index_t expert_tile = expert_tile_; \ + using ms_problem = \ + ck_tile::MoeSortingProblem; \ + using kernel = ck_tile::MoeSortingKernel; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + const auto lds_bytes = kernel::GetSmemSize(a); \ + float ave_time = ck_tile::launch_kernel( \ + s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ return ave_time; +#define MOE_SORTING_DISPATCH(unroll_num_) \ + if(a.num_experts <= 8) \ + { \ + MOE_SORTING_DISPATCH_ETILE(unroll_num_, 8) \ + } \ + else if(a.num_experts <= 16) \ + { \ + MOE_SORTING_DISPATCH_ETILE(unroll_num_, 16) \ + } \ + else if(a.num_experts <= 32) \ + { \ + MOE_SORTING_DISPATCH_ETILE(unroll_num_, 32) \ + } \ + else if(a.num_experts <= 64) \ + { \ + MOE_SORTING_DISPATCH_ETILE(unroll_num_, 64) \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \ + } + float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s) { if(t.weight_type == "fp32" && t.index_type == "int32") @@ -49,21 +73,12 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til case(6): { MOE_SORTING_DISPATCH(6); } - case(7): { - MOE_SORTING_DISPATCH(7); - } case(8): { MOE_SORTING_DISPATCH(8); } - case(9): { - MOE_SORTING_DISPATCH(9); - } case(10): { MOE_SORTING_DISPATCH(10); } - case(11): { - MOE_SORTING_DISPATCH(11); - } default: { MOE_SORTING_DISPATCH(4); } diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index d9e28ceb52..30e68996b6 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -130,7 +130,8 @@ struct MoeSortingKernel CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h) { const auto blocks = BlockSize(h); - return ((blocks.x + 1) * h.num_experts + (h.num_experts + 1)) * sizeof(index_t); + // usually num_experts is power of 2, we pad 1 dword here for the row-size + return ((blocks.x + 1) * (h.num_experts + 1) + (h.num_experts + 1)) * sizeof(index_t); } CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) @@ -154,6 +155,75 @@ struct MoeSortingKernel return k; } + // [a, b, c, d....] -> [a, a+b, a+b+c, a+b+c+d, ....] + template + __device__ inline void wave_cumsum(data_t& thread_data) const + { + // wave_size must be power of 2 + constexpr int row_mask = 0xf; + constexpr int bank_mask = 0xf; + constexpr bool bound_ctrl = true; // ! out-of-bound is zero ! + auto reduce_op = [&](auto x_, auto y_) { return x_ + y_; }; + + if constexpr(wave_size > 1) + { + thread_data = reduce_op( + thread_data, + __builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data), + 0x111, + row_mask, + bank_mask, + bound_ctrl))); // row_shr:1 + } + + if constexpr(wave_size > 2) + { + thread_data = reduce_op( + thread_data, + __builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data), + 0x112, + row_mask, + bank_mask, + bound_ctrl))); // row_shr:2 + } + if constexpr(wave_size > 4) + { + thread_data = + reduce_op(thread_data, + __builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data), + 0x114, + row_mask, + bank_mask, + bound_ctrl))); // row_shr:4 + } + if constexpr(wave_size > 8) + { + thread_data = + reduce_op(thread_data, + __builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data), + 0x118, + row_mask, + bank_mask, + bound_ctrl))); // row_shr:8 + } + + if constexpr(wave_size > 16) + { + // now row-0, row-0+row-1, row-1+row-2, row-2+row-3 + int v_remote_tmp = __builtin_amdgcn_ds_bpermute(((__lane_id() & 0x30) - 1) << 2, __builtin_bit_cast(int, thread_data)); + v_remote_tmp = __lane_id() >= 16 ? v_remote_tmp : 0; + thread_data = reduce_op(thread_data, __builtin_bit_cast(data_t, v_remote_tmp)); + } + + if constexpr(wave_size > 32) + { + // lane-id 48...63->31 + int v_remote_tmp = __builtin_amdgcn_ds_bpermute(((__lane_id() & 0x30) - 17) << 2, __builtin_bit_cast(int, thread_data)); + v_remote_tmp = __lane_id() >= 32 ? v_remote_tmp : 0; + thread_data = reduce_op(thread_data, __builtin_bit_cast(data_t, v_remote_tmp)); + } + } + CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const { return row * total_col + col; @@ -187,48 +257,124 @@ struct MoeSortingKernel index_t* shared_mem = reinterpret_cast(smem); index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts) - index_t* cumsum = shared_mem + (blockDim.x + 1) * num_experts; // 1: (num_experts + 1) + index_t* cumsum = shared_mem + (blockDim.x + 1) * (num_experts+1); // 1: (num_experts + 1) + for(int i = 0; i < num_experts; ++i) { - tokens_cnts[calc_index(num_experts, tid + 1, i)] = 0; + tokens_cnts[calc_index(num_experts+1, tid + 1, i)] = 0; } + #pragma unroll Problem_::InternalLoadUnroll for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - ++tokens_cnts[calc_index(num_experts, tid + 1, topk_id[i])]; + ++tokens_cnts[calc_index(num_experts+1, tid + 1, topk_id[i])]; } __syncthreads(); +#if 1 if(tid < num_experts) { - tokens_cnts[calc_index(num_experts, 0, tid)] = 0; - for(int i = 1; i <= static_cast(blockDim.x); ++i) + tokens_cnts[calc_index(num_experts+1, 0, tid)] = 0; + index_t local_c[8]; + index_t prev_c = 0; + // TODO: manually unroll. pragma unroll does not work well when we have dependency + for(int i = 1; i <= static_cast(blockDim.x); i+= 8) { - tokens_cnts[calc_index(num_experts, i, tid)] += - tokens_cnts[calc_index(num_experts, i - 1, tid)]; + local_c[0] = tokens_cnts[calc_index(num_experts+1, i + 0, tid)]; + local_c[1] = tokens_cnts[calc_index(num_experts+1, i + 1, tid)]; + local_c[2] = tokens_cnts[calc_index(num_experts+1, i + 2, tid)]; + local_c[3] = tokens_cnts[calc_index(num_experts+1, i + 3, tid)]; + local_c[4] = tokens_cnts[calc_index(num_experts+1, i + 4, tid)]; + local_c[5] = tokens_cnts[calc_index(num_experts+1, i + 5, tid)]; + local_c[6] = tokens_cnts[calc_index(num_experts+1, i + 6, tid)]; + local_c[7] = tokens_cnts[calc_index(num_experts+1, i + 7, tid)]; + + local_c[0] += prev_c; + local_c[1] += local_c[0]; + local_c[2] += local_c[1]; + local_c[3] += local_c[2]; + local_c[4] += local_c[3]; + local_c[5] += local_c[4]; + local_c[6] += local_c[5]; + local_c[7] += local_c[6]; + prev_c = local_c[7]; + + tokens_cnts[calc_index(num_experts+1, i + 0, tid)] = local_c[0]; + tokens_cnts[calc_index(num_experts+1, i + 1, tid)] = local_c[1]; + tokens_cnts[calc_index(num_experts+1, i + 2, tid)] = local_c[2]; + tokens_cnts[calc_index(num_experts+1, i + 3, tid)] = local_c[3]; + tokens_cnts[calc_index(num_experts+1, i + 4, tid)] = local_c[4]; + tokens_cnts[calc_index(num_experts+1, i + 5, tid)] = local_c[5]; + tokens_cnts[calc_index(num_experts+1, i + 6, tid)] = local_c[6]; + tokens_cnts[calc_index(num_experts+1, i + 7, tid)] = local_c[7]; + } + } +#else + // TODO: below code still working, but slow in expert=32/topk=5 case. Put here for future heuristic + { + if(tid < num_experts) + tokens_cnts[calc_index(num_experts+1, 0, tid)] = 0; + for(int i = 0; i < num_experts; i+=8) { + index_t local_c[8]; + #pragma unroll + for(int j = 0; j < 8; j++) { + local_c[j] = tokens_cnts[calc_index(num_experts+1, tid+1, i+j)]; + } + + #pragma unroll + for(int j = 0; j < 8; j++) { + wave_cumsum(local_c[j]); + } + + #pragma unroll + for(int j = 0; j < 8; j++) { + tokens_cnts[calc_index(num_experts+1, tid+1, i+j)] = local_c[j]; + } + } + } +#endif + + __syncthreads(); + if constexpr (Problem::ExpertTile == 0) { + if(tid == 0) + { + cumsum[0] = 0; + for(int i = 1; i <= num_experts; ++i) + { + auto current_units = [&]() { + index_t x_ = tokens_cnts[calc_index(num_experts+1, blockDim.x, i - 1)] + + unit_size_mdiv.divisor - 1; + index_t y_ = unit_size_mdiv.div(x_); + return max(y_, 1) * unit_size_mdiv.divisor; + }(); + cumsum[i] = cumsum[i - 1] + current_units; + } + *p_total_tokens_post_pad = cumsum[num_experts]; + } + } else { + // TODO: we have out-of-bound read here. But result is still OK (will ignore tid >= expert) + // for simplicity, not check experts here. + int local_cnt = tokens_cnts[calc_index(num_experts+1, blockDim.x, tid)]; + int blocks_pers_expert = unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1); + int padded_tokens_per_expert = max(blocks_pers_expert, 1) * unit_size_mdiv.divisor; + int local_cumsum = padded_tokens_per_expert; + wave_cumsum(local_cumsum); + + if(tid == (num_experts - 1)) { + cumsum[0] = 0; + *p_total_tokens_post_pad = local_cumsum; + } + if(tid < num_experts) { + cumsum[tid + 1] = local_cumsum; } } - // __syncthreads(); - if(tid == 0) - { - cumsum[0] = 0; - for(int i = 1; i <= num_experts; ++i) - { - auto current_units = [&]() { - index_t x_ = tokens_cnts[calc_index(num_experts, blockDim.x, i - 1)] + - unit_size_mdiv.divisor - 1; - index_t y_ = unit_size_mdiv.div(x_); - return max(y_, 1) * unit_size_mdiv.divisor; - }(); - cumsum[i] = cumsum[i - 1] + current_units; - } - *p_total_tokens_post_pad = cumsum[num_experts]; - } __syncthreads(); if(tid < num_experts) { - for(int i = cumsum[tid]; i < cumsum[tid + 1]; i += unit_size_mdiv.divisor) + int e_start = cumsum[tid]; + int e_end = cumsum[tid + 1]; + for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor) { p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid; } @@ -238,8 +384,8 @@ struct MoeSortingKernel for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { index_t expert_id = topk_id[i]; - index_t rank_post_pad = - tokens_cnts[calc_index(num_experts, tid, expert_id)] + cumsum[expert_id]; + index_t local_cnt = tokens_cnts[calc_index(num_experts+1, tid, expert_id)]; + index_t rank_post_pad = local_cnt + cumsum[expert_id]; #if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID uint32_t curr_token_id, curr_topk_id; topk_mdiv.divmod(i, curr_token_id, curr_topk_id); @@ -247,27 +393,54 @@ struct MoeSortingKernel #else p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i); #endif - p_sorted_weights[rank_post_pad] = weights[i]; - ++tokens_cnts[calc_index(num_experts, tid, expert_id)]; + p_sorted_weights[rank_post_pad] = weights[i]; + tokens_cnts[calc_index(num_experts+1, tid, expert_id)] = local_cnt+1; } - const index_t prefill_token = topk_mdiv.div(numel); - if(tid < num_experts) - { - index_t expert_offset = - cumsum[tid] + tokens_cnts[calc_index(num_experts, blockDim.x, tid)]; - while(expert_offset < cumsum[tid + 1]) + if constexpr (Problem::ExpertTile == 0) { + const index_t prefill_token = topk_mdiv.div(numel); + if(tid < num_experts) { + index_t expert_offset = + cumsum[tid] + tokens_cnts[calc_index(num_experts+1, blockDim.x, tid)]; + index_t expert_end = cumsum[tid + 1]; + while(expert_offset < expert_end) + { #if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID - p_sorted_token_ids[expert_offset] = - MOE_SORTING_MOCK_ID(prefill_token, topk_mdiv.divisor); + p_sorted_token_ids[expert_offset] = + MOE_SORTING_MOCK_ID(prefill_token, topk_mdiv.divisor); #else - p_sorted_token_ids[expert_offset] = prefill_token; + p_sorted_token_ids[expert_offset] = prefill_token; #endif - p_sorted_weights[expert_offset] = static_cast(0.0); - expert_offset++; + p_sorted_weights[expert_offset] = static_cast(0.0); + expert_offset++; + } } } + else { + const index_t prefill_token = topk_mdiv.div(numel); + // TODO: only support expert-tile like 8, 16, 32 + static constexpr index_t experts_per_wave = warpSize / Problem::ExpertTile; + { + index_t eid = tid / experts_per_wave; + index_t expert_offset = + cumsum[eid] + tokens_cnts[calc_index(num_experts+1, blockDim.x, eid)] + tid % experts_per_wave; + index_t expert_end = cumsum[eid + 1]; + if(eid < num_experts) { + while(expert_offset < expert_end) + { +#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID + p_sorted_token_ids[expert_offset] = + MOE_SORTING_MOCK_ID(prefill_token, topk_mdiv.divisor); +#else + p_sorted_token_ids[expert_offset] = prefill_token; +#endif + p_sorted_weights[expert_offset] = static_cast(0.0); + expert_offset+=experts_per_wave; + } + } + } + } } CK_TILE_DEVICE void operator()(Kargs kargs) const diff --git a/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp b/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp index adde59e356..50005c4402 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp @@ -9,15 +9,20 @@ namespace ck_tile { -template +template struct MoeSortingProblem { // TODO: this kernel only support warp per row using WeightType = remove_cvref_t; using IndexType = remove_cvref_t; - static constexpr index_t WarpSize = get_warp_size(); - static constexpr index_t WarpsPerBlock = 1; - static constexpr index_t InternalLoadUnroll = InternalLoadUnroll_; + static constexpr index_t WarpSize = get_warp_size(); + static constexpr index_t WarpsPerBlock = 1; + static constexpr index_t InternalLoadUnroll = + InternalLoadUnroll_; // TODO: need better design(like tile size) + static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out }; } // namespace ck_tile