diff --git a/example/ck_tile/13_moe_sorting/script/smoke_test.sh b/example/ck_tile/13_moe_sorting/script/smoke_test.sh index 44a1929f7c..775e5ac7bb 100644 --- a/example/ck_tile/13_moe_sorting/script/smoke_test.sh +++ b/example/ck_tile/13_moe_sorting/script/smoke_test.sh @@ -19,5 +19,5 @@ $EXE -t=99 -e=2 -k=1 $EXE -t=333 -e=99 -k=13 $EXE -t=11 -e=256 -k=5 $EXE -t=64 -e=455 -k=8 -# $EXE -t=777 -e=802 -k=99 # has bug??? +$EXE -t=777 -e=802 -k=99 # has bug??? $EXE -t=128 -e=32 -k=5 -moe_buf_size=262144 diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 438a4f4600..9678f0de7d 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -178,7 +178,7 @@ struct MoeSortingKernel return r_for_sub_token + cumsum_bufs; }(); - printf("r:%d, c:%d\n", smem_rows, smem_cols); + // printf("r:%d, c:%d\n", smem_rows, smem_cols); return ck_tile::make_tuple(smem_rows, smem_cols); } @@ -668,26 +668,79 @@ struct MoeSortingKernel for(int i_token = 0; i_token < tokens; i_token += sub_tokens) // int i_token = 0; { +#if 1 + __syncthreads(); // #pragma unroll 8 for(int i = tid; i < (sub_tokens * topk); i += block_size) { uint32_t curr_token_id, curr_topk_id; topk_mdiv.divmod(i, curr_token_id, curr_topk_id); int i_t = i_token + curr_token_id; - // printf("--- tid:%d, i_token:%d, curr_token_id:%d, curr_topk_id:%d, tokens:%d\n", - // tid, i_token, curr_token_id,curr_topk_id, tokens); + + if(i_t < tokens) + { + + int eid = topk_id[i_t * topk + curr_topk_id]; + + // printf("eid:%d, [%d] tid:%d, (i_token:%d, curr_token_id:%d)i_t:%d, + // curr_topk_id:%d, tokens:%d\n", + // eid, i, tid, i_token, curr_token_id, i_t, curr_topk_id, tokens); + smem_tokens(curr_token_id, eid)++; + } + __builtin_amdgcn_s_waitcnt(0xc07f); + // + } + __syncthreads(); + // if(tid == 0) { + // int e0 = smem_tokens(0, 0); + // int e1 = smem_tokens(1, 0); + // int e2 = smem_tokens(2, 0); + // int e3 = smem_tokens(3, 0); + // int e4 = smem_tokens(4, 0); + // int e5 = smem_tokens(5, 0); + // int e6 = smem_tokens(6, 0); + // int e7 = smem_tokens(7, 0); + // printf("xxx eid:%d i_token:%d, cnt:%d,%d,%d,%d,%d,%d,%d,%d(%d)\n", 0, i_token, + // e0, + // e1, + // e2, + // e3, + // e4, + // e5, + // e6, + // e7, + // e0+e1+e2+e3+e4+e5+e6+e7 + // ); + // } + +#else + int i = tid; + while(true) + { + __syncthreads(); + if(i >= (sub_tokens * topk)) + break; + uint32_t curr_token_id, curr_topk_id; + topk_mdiv.divmod(i, curr_token_id, curr_topk_id); + int i_t = i_token + curr_token_id; + // printf("[%d] tid:%d, (i_token:%d, curr_token_id:%d)i_t:%d, curr_topk_id:%d, + // tokens:%d\n", + // i, tid, i_token, curr_token_id, i_t, curr_topk_id, tokens); if(i_t < tokens) { int eid = topk_id[i_t * topk + curr_topk_id]; smem_tokens(curr_token_id, eid)++; } + + i += block_size; } + __syncthreads(); +#endif } - __syncthreads(); // counting smem_cumsum(0) = 0; -#if 1 +#if 0 (void)f_sum; for(int i_e = tid; i_e < num_experts; i_e += block_size) { @@ -727,13 +780,7 @@ struct MoeSortingKernel { index_t local_c[8]; index_t cnt = 0; - // TODO: manually unroll. pragma unroll does not work well when we have dependency - // for(int i = 0; i < sub_tokens; i+= 8) - // { - // local_c[0] = smem_tokens(i + lane_group_os, i_e); - // int sum_ = wave_reduce(local_c[0], f_sum); - // cnt += sum_; - // } + for(int i = 0; i < sub_tokens; i += 8 * 8) { local_c[0] = smem_tokens(i + 0 * 8 + lane_group_os, i_e); @@ -755,7 +802,25 @@ struct MoeSortingKernel // local_c[5], // local_c[6], // local_c[7]); - +#if 1 + cnt += + (i + 0 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[0], f_sum, number<8>{}); + cnt += + (i + 1 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[1], f_sum, number<8>{}); + cnt += + (i + 2 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[2], f_sum, number<8>{}); + cnt += + (i + 3 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[3], f_sum, number<8>{}); + cnt += + (i + 4 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[4], f_sum, number<8>{}); + cnt += + (i + 5 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[5], f_sum, number<8>{}); + cnt += + (i + 6 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[6], f_sum, number<8>{}); + cnt += + (i + 7 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[7], f_sum, number<8>{}); +#else + // TODO: this rely on LDS OOB behavior, too hardware specific cnt += wave_reduce(local_c[0], f_sum, number<8>{}); cnt += wave_reduce(local_c[1], f_sum, number<8>{}); cnt += wave_reduce(local_c[2], f_sum, number<8>{}); @@ -764,6 +829,7 @@ struct MoeSortingKernel cnt += wave_reduce(local_c[5], f_sum, number<8>{}); cnt += wave_reduce(local_c[6], f_sum, number<8>{}); cnt += wave_reduce(local_c[7], f_sum, number<8>{}); +#endif } if(lane_group_os == 0) smem_cumsum(i_e + 1) = cnt; @@ -780,7 +846,7 @@ struct MoeSortingKernel (void)wid; for(int i = 1; i <= num_experts; ++i) { - //printf("e:%d -- %d (%d) \n", i - 1, smem_cumsum(i), sub_tokens); + // printf("e:%d -- %d (%d) \n", i - 1, smem_cumsum(i), sub_tokens); auto current_units = [&]() { index_t x_ = smem_cumsum(i) + unit_size_mdiv.divisor - 1; index_t y_ = unit_size_mdiv.div(x_);