diff --git a/example/ck_tile/13_moe_sorting/moe_sorting.cpp b/example/ck_tile/13_moe_sorting/moe_sorting.cpp index dbbcb6130a..c4faa35e33 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting.cpp @@ -131,8 +131,6 @@ bool test_moe_sorting(ck_tile::ArgParser args) ck_tile::FillUniformDistribution{-.5f, .5f}(weights_host); ck_tile::FillUniformDistribution{-.5f, .5f}(moe_buf_host); topid_unique_gen(topk_ids_host.mData, tokens, topk, num_experts, seed); - // std::cout << "topk_id:" << topk_ids_host << std::endl; - // std::cout << "local_expert_masking:" << local_expert_masking_host << std::endl; ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem weights_dev(weights_host.get_element_space_size_in_bytes()); @@ -177,15 +175,22 @@ bool test_moe_sorting(ck_tile::ArgParser args) warmup, repeat}; auto ms = moe_sorting(trait, karg, sc); - printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ms:%f , ", + printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ", index_prec.c_str(), weight_prec.c_str(), tokens, num_experts, - topk, - ms); + topk); + + if(local_expert_masking) + { + printf("local_eid:%s, ", args.get_str("local_eid").c_str()); + } + if(ms < 0) printf("not supported\n"); + else + printf("ms:%f, ", ms); fflush(stdout); if(ms < 0) { @@ -221,19 +226,16 @@ bool test_moe_sorting(ck_tile::ArgParser args) local_expert_masking); rtn &= ck_tile::check_err( sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6); - // std::cout << "sorted_ids_ref:"< moe_buf_ref({moe_buf_size}); diff --git a/example/ck_tile/13_moe_sorting/script/smoke_test.sh b/example/ck_tile/13_moe_sorting/script/smoke_test.sh index 659d0223a7..cf2c2e164b 100644 --- a/example/ck_tile/13_moe_sorting/script/smoke_test.sh +++ b/example/ck_tile/13_moe_sorting/script/smoke_test.sh @@ -22,3 +22,7 @@ $EXE -t=64 -e=455 -k=8 $EXE -t=777 -e=802 -k=99 $EXE -t=4097 -e=906 -k=51 $EXE -t=128 -e=32 -k=5 -moe_buf_size=262144 +$EXE -t=13 -e=64 -k=3 -local_eid=4,5,6,7,8,9,10,11 +$EXE -t=99 -e=33 -k=9 -local_eid=6,10,11,15,19 +$EXE -t=80 -e=99 -k=10 -local_eid=0,8,12,33 +$EXE -t=11 -e=256 -k=5 -local_eid=99,110,129 diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index ba4f4b6e7d..a8c95b9c38 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -27,12 +27,12 @@ #include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/int8.hpp" -#include "ck_tile/core/numeric/pk_int4.hpp" #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/numeric/null_type.hpp" #include "ck_tile/core/numeric/numeric.hpp" +#include "ck_tile/core/numeric/pk_int4.hpp" #include "ck_tile/core/numeric/type_convert.hpp" #include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/core/tensor/buffer_view.hpp" diff --git a/include/ck_tile/host/reference/reference_moe_sorting.hpp b/include/ck_tile/host/reference/reference_moe_sorting.hpp index 62070e7613..47f0ba576b 100644 --- a/include/ck_tile/host/reference/reference_moe_sorting.hpp +++ b/include/ck_tile/host/reference/reference_moe_sorting.hpp @@ -104,7 +104,6 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor& topk_ids, for(index_t s = 0; s < expert_slices[e]; s++) { - // out_expert_id[s] = e; out_expert_id[s] = curr_expert_id; unit_cnt++; } diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 84f6ff1b0d..340f6cb9e5 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -351,7 +351,6 @@ struct MoeSortingKernel bound_ctrl)); // row_newbcast:7 data_t yyy = (__lane_id() / 8) % 2 == 0 ? 0 : xxx; - // printf("[%d]eid:%d, thread_data:%d, xxx:%d, yyy:%d (%d)\n", threadIdx.x, threadIdx.x/8, thread_data, xxx, yyy, (__lane_id() / 8) % 2); thread_data = thread_data - yyy; #endif @@ -683,12 +682,9 @@ struct MoeSortingKernel index_t* p_total_tokens_post_pad, const index_t num_experts, const index_t tokens, - // const index_t tokens_per_thread, - // const index_t numel, const mdiv unit_size_mdiv, const mdiv topk_mdiv, const mdiv expert_mdiv, - // const mdiv sub_tokens_mdiv, const index_t smem_rows, void* smem) const { @@ -701,18 +697,12 @@ struct MoeSortingKernel auto f_sum = [](auto x_, auto y_) { return x_ + y_; }; const index_t smem_cols = num_experts + 1; - // const index_t total_smem_tokens_pixel = sub_tokens * num_experts; // no need consider - // padding -#if 0 - simple_smem_indexer smem_tokens{reinterpret_cast(smem), smem_cols}; - simple_smem_indexer smem_cumsum{reinterpret_cast(smem) + sub_tokens * smem_cols}; - simple_smem_indexer smem_cumdup{reinterpret_cast(smem) + sub_tokens * smem_cols + smem_cols}; -#else + simple_smem_indexer smem_cumsum{reinterpret_cast(smem) + 0}; simple_smem_indexer smem_cumdup{reinterpret_cast(smem) + smem_cols}; simple_smem_indexer smem_tokens{reinterpret_cast(smem) + 2 * smem_cols, smem_cols}; -#endif + // #pragma unroll 8 for(int i = tid; i < (sub_tokens * num_experts); i += block_size) { @@ -733,13 +723,8 @@ struct MoeSortingKernel if(i_t < tokens) { - int eid = topk_id[i_t * topk + curr_topk_id]; - // printf("eid:%d, [%d] tid:%d, (i_token:%d, curr_token_id:%d)i_t:%d, - // curr_topk_id:%d, tokens:%d\n", - // eid, i, tid, i_token, curr_token_id, i_t, curr_topk_id, tokens); - // smem_tokens(curr_token_id, eid)++; if constexpr(Problem::SubTokenOneShot) smem_tokens(curr_token_id, eid) = curr_topk_id + 1; else @@ -748,27 +733,6 @@ struct MoeSortingKernel __builtin_amdgcn_s_waitcnt(0xc07f); } __syncthreads(); // make sure different i_token iteration not overlap by different wave - // if(tid == 0) { - // int e0 = smem_tokens(0, 0); - // int e1 = smem_tokens(1, 0); - // int e2 = smem_tokens(2, 0); - // int e3 = smem_tokens(3, 0); - // int e4 = smem_tokens(4, 0); - // int e5 = smem_tokens(5, 0); - // int e6 = smem_tokens(6, 0); - // int e7 = smem_tokens(7, 0); - // printf("xxx eid:%d i_token:%d, cnt:%d,%d,%d,%d,%d,%d,%d,%d(%d)\n", 0, i_token, - // e0, - // e1, - // e2, - // e3, - // e4, - // e5, - // e6, - // e7, - // e0+e1+e2+e3+e4+e5+e6+e7 - // ); - // } } // counting @@ -777,36 +741,7 @@ struct MoeSortingKernel smem_cumsum(0) = 0; // smem_cumdup(0) = 0; } -#if 0 - (void)f_sum; - for(int i_e = tid; i_e < num_experts; i_e += block_size) - { - index_t local_c[8]; - index_t cnt = 0; - // TODO: manually unroll. pragma unroll does not work well when we have dependency - for(int i = 0; i < sub_tokens; i += 8) - { - local_c[0] = smem_tokens(i + 0, i_e); - local_c[1] = smem_tokens(i + 1, i_e); - local_c[2] = smem_tokens(i + 2, i_e); - local_c[3] = smem_tokens(i + 3, i_e); - local_c[4] = smem_tokens(i + 4, i_e); - local_c[5] = smem_tokens(i + 5, i_e); - local_c[6] = smem_tokens(i + 6, i_e); - local_c[7] = smem_tokens(i + 7, i_e); - cnt += local_c[0]; - cnt += local_c[1]; - cnt += local_c[2]; - cnt += local_c[3]; - cnt += local_c[4]; - cnt += local_c[5]; - cnt += local_c[6]; - cnt += local_c[7]; - } - smem_cumsum(i_e + 1) = cnt; - } -#else { constexpr int lane_group_sz = 8; int lane_group_id = tid / lane_group_sz; @@ -835,61 +770,11 @@ struct MoeSortingKernel { cnt += wave_reduce(local_c[j], f_sum, number<8>{}); } - - // if constexpr(Problem::SubTokenTile == 2) - // printf("i_e:%d, lane_group_os:%d -> %d, %d\n", - // i_e, lane_group_os, - // local_c[0], - // local_c[1]); -// -// printf("i_e:%d, lane_group_os:%d, %d, %d, %d, %d, %d, %d, %d, %d\n", -// i_e, lane_group_os, -// local_c[0], -// local_c[1], -// local_c[2], -// local_c[3], -// local_c[4], -// local_c[5], -// local_c[6], -// local_c[7]); -#if 0 -#if 1 - cnt += - (i + 0 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[0], f_sum, number<8>{}); - cnt += - (i + 1 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[1], f_sum, number<8>{}); - cnt += - (i + 2 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[2], f_sum, number<8>{}); - cnt += - (i + 3 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[3], f_sum, number<8>{}); - cnt += - (i + 4 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[4], f_sum, number<8>{}); - cnt += - (i + 5 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[5], f_sum, number<8>{}); - cnt += - (i + 6 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[6], f_sum, number<8>{}); - cnt += - (i + 7 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[7], f_sum, number<8>{}); -#else - // TODO: this rely on LDS OOB behavior, too hardware specific - cnt += wave_reduce(local_c[0], f_sum, number<8>{}); - cnt += wave_reduce(local_c[1], f_sum, number<8>{}); - cnt += wave_reduce(local_c[2], f_sum, number<8>{}); - cnt += wave_reduce(local_c[3], f_sum, number<8>{}); - cnt += wave_reduce(local_c[4], f_sum, number<8>{}); - cnt += wave_reduce(local_c[5], f_sum, number<8>{}); - cnt += wave_reduce(local_c[6], f_sum, number<8>{}); - cnt += wave_reduce(local_c[7], f_sum, number<8>{}); -#endif -#endif } if(lane_group_os == 0) smem_cumsum(i_e + 1) = cnt; - - // printf("i_e:%d, cnt:%d\n", i_e, cnt); } } -#endif if constexpr(Problem::LocalExpertMasking) { @@ -897,43 +782,22 @@ struct MoeSortingKernel for(int i_e = tid; i_e < num_experts; i_e += block_size) { // reuse this buffer - // printf("tid:%d, m:%d\n", tid, local_expert_mask[i_e]); smem_cumdup(i_e + 1) = local_expert_mask[i_e]; } } __syncthreads(); -#if 0 - if(tid == 0) - { - (void)lid; - (void)wid; - for(int i = 1; i <= num_experts; ++i) - { - // printf("e:%d -- %d (%d) \n", i - 1, smem_cumsum(i), sub_tokens); - auto current_units = [&]() { - index_t x_ = smem_cumsum(i) + unit_size_mdiv.divisor - 1; - index_t y_ = unit_size_mdiv.div(x_); - return max(y_, 1) * unit_size_mdiv.divisor; - }(); - smem_cumsum(i) = smem_cumsum(i - 1) + current_units; - } - *p_total_tokens_post_pad = smem_cumsum(num_experts); - } - __syncthreads(); -#else + { if(wid == 0) { // NOTE: under this block can never use __syncthreads! int i_e_ = 0; int local_cumsum_ = 0; - // int pre_cumsum_ = 0; for(; i_e_ < num_experts; i_e_ += warpSize) { int pre_cumsum_ = smem_cumsum(lid == 0 ? i_e_ : 0); - // if((i_e_+lid) < num_experts) - int local_cnt = smem_cumsum(i_e_ + lid + 1); + int local_cnt = smem_cumsum(i_e_ + lid + 1); int blocks_pers_expert = unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1); @@ -976,10 +840,7 @@ struct MoeSortingKernel // pre_sumsum has value, which will result int // zero local cumsum(but we want at least padded) wave_cumsum(local_cumsum_); - // printf(" lid:%d(%d), local_cnt:%d,pre_cumsum_:%d, %d--> %d (m:%d)\n", lid, - // i_e_ + - // lid, local_cnt, pre_cumsum_, padded_tokens_per_expert,local_cumsum_ - // ,local_masking); + if((i_e_ + lid) < num_experts) smem_cumsum(i_e_ + lid + 1) = local_cumsum_; @@ -1003,7 +864,6 @@ struct MoeSortingKernel } __syncthreads(); } -#endif for(int i_e = tid; i_e < num_experts; i_e += block_size) { @@ -1020,9 +880,6 @@ struct MoeSortingKernel return i_e; }(); - // printf("i_e:%d, e_start:%d, e_end:%d, expert_id:%d (%d-%d, m:%d)\n", i_e, e_start, - // e_end, expert_id, e_start, e_end, local_expert_mask[i_e]); - smem_cumdup(i_e) = e_start; // duplicate cumsum for later use if constexpr(Problem::SkipExpertsWithZeroTokens) { @@ -1041,7 +898,6 @@ struct MoeSortingKernel p_sorted_expert_ids[unit_size_mdiv.div(i)] = expert_id; } } - // if (tid == 0) smem_cumdup(num_experts) = smem_cumsum(num_experts); // fill the p_sorted_token_ids/p_sorted_weights @@ -1068,40 +924,13 @@ struct MoeSortingKernel int i_t = i_token + curr_token_id; if(i_t < tokens) { - int eid = topk_id[i_t * topk + curr_topk_id]; - // if(eid == 0) { - // printf("@@@ eid:%d, i_t:%d, cur:%d, curr_topk_id:%d\n", eid, i_t, - // curr_token_id, curr_topk_id); printf("## eid:%d,%d\n", i_t, - // curr_topk_id); - //} + int eid = topk_id[i_t * topk + curr_topk_id]; smem_tokens(curr_token_id, eid) = curr_topk_id + 1; // at least 1 } } __syncthreads(); } -#if 0 - for(int eid = tid; eid < num_experts; eid += block_size) { - // indeed we can unroll 8x - for(int i_sub_token = 0; i_sub_token < sub_tokens; i_sub_token++) { - auto x = smem_tokens(i_sub_token, eid); - //if (eid == 0) - // printf("@@ eid:%d, pos:%d, i_sub_token:%d, x:%d\n", eid, smem_cumsum(eid), i_sub_token, x); - if(x != 0) { - // now x is topk value - int position = smem_cumsum(eid); -#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID - p_sorted_token_ids[position] = MOE_SORTING_MOCK_ID(i_token + i_sub_token, x - 1); -#else - p_sorted_token_ids[position] = i_token + i_sub_token; -#endif - p_sorted_weights[position] = weights[(i_token + i_sub_token) * topk + x - 1]; - smem_cumsum(eid) = position + 1; // increase position - } - // __syncthreads(); - } - } -#else { constexpr int lane_group_sz = 8; int lane_group_id = tid / lane_group_sz; @@ -1138,17 +967,12 @@ struct MoeSortingKernel int remote_cnt = __builtin_amdgcn_ds_bpermute( (lane_group_sz * (lane_group_id + 1) - 1) << 2, local_cnt); - // printf("[%d]eid:%d, i_sub_token:%d, position:%d, x:%d, local_cnt:%d(%d), - // remote_cnt:%d\n", - // tid, eid, i_sub_token, position, x, local_cnt, local_cnt_cache, - // remote_cnt); + position += remote_cnt; } smem_cumsum(eid) = position; } } -#endif - // (void) weights; __syncthreads(); } @@ -1157,7 +981,6 @@ struct MoeSortingKernel { int e_start = smem_cumsum(eid); int e_end = smem_cumdup(eid + 1); - // printf("--- eid:%d, e_start:%d, e_end:%d\n", eid, e_start, e_end); if constexpr(Problem::SkipExpertsWithZeroTokens) { if(e_start == e_end) // skip zero token expert @@ -1201,8 +1024,6 @@ struct MoeSortingKernel static_cast(kargs.p_total_tokens_post_pad), kargs.num_experts, kargs.tokens, - // kargs.tokens_per_thread, - // numel, kargs.unit_size_mdiv, kargs.topk_mdiv, kargs.expert_mdiv,