fix bug for race condition

This commit is contained in:
carlushuang
2025-01-27 15:14:37 +08:00
parent 8e61104781
commit e6d40dbebc
2 changed files with 81 additions and 15 deletions

View File

@@ -19,5 +19,5 @@ $EXE -t=99 -e=2 -k=1
$EXE -t=333 -e=99 -k=13
$EXE -t=11 -e=256 -k=5
$EXE -t=64 -e=455 -k=8
# $EXE -t=777 -e=802 -k=99 # has bug???
$EXE -t=777 -e=802 -k=99 # has bug???
$EXE -t=128 -e=32 -k=5 -moe_buf_size=262144

View File

@@ -178,7 +178,7 @@ struct MoeSortingKernel
return r_for_sub_token + cumsum_bufs;
}();
printf("r:%d, c:%d\n", smem_rows, smem_cols);
// printf("r:%d, c:%d\n", smem_rows, smem_cols);
return ck_tile::make_tuple(smem_rows, smem_cols);
}
@@ -668,26 +668,79 @@ struct MoeSortingKernel
for(int i_token = 0; i_token < tokens; i_token += sub_tokens)
// int i_token = 0;
{
#if 1
__syncthreads();
// #pragma unroll 8
for(int i = tid; i < (sub_tokens * topk); i += block_size)
{
uint32_t curr_token_id, curr_topk_id;
topk_mdiv.divmod(i, curr_token_id, curr_topk_id);
int i_t = i_token + curr_token_id;
// printf("--- tid:%d, i_token:%d, curr_token_id:%d, curr_topk_id:%d, tokens:%d\n",
// tid, i_token, curr_token_id,curr_topk_id, tokens);
if(i_t < tokens)
{
int eid = topk_id[i_t * topk + curr_topk_id];
// printf("eid:%d, [%d] tid:%d, (i_token:%d, curr_token_id:%d)i_t:%d,
// curr_topk_id:%d, tokens:%d\n",
// eid, i, tid, i_token, curr_token_id, i_t, curr_topk_id, tokens);
smem_tokens(curr_token_id, eid)++;
}
__builtin_amdgcn_s_waitcnt(0xc07f);
//
}
__syncthreads();
// if(tid == 0) {
// int e0 = smem_tokens(0, 0);
// int e1 = smem_tokens(1, 0);
// int e2 = smem_tokens(2, 0);
// int e3 = smem_tokens(3, 0);
// int e4 = smem_tokens(4, 0);
// int e5 = smem_tokens(5, 0);
// int e6 = smem_tokens(6, 0);
// int e7 = smem_tokens(7, 0);
// printf("xxx eid:%d i_token:%d, cnt:%d,%d,%d,%d,%d,%d,%d,%d(%d)\n", 0, i_token,
// e0,
// e1,
// e2,
// e3,
// e4,
// e5,
// e6,
// e7,
// e0+e1+e2+e3+e4+e5+e6+e7
// );
// }
#else
int i = tid;
while(true)
{
__syncthreads();
if(i >= (sub_tokens * topk))
break;
uint32_t curr_token_id, curr_topk_id;
topk_mdiv.divmod(i, curr_token_id, curr_topk_id);
int i_t = i_token + curr_token_id;
// printf("[%d] tid:%d, (i_token:%d, curr_token_id:%d)i_t:%d, curr_topk_id:%d,
// tokens:%d\n",
// i, tid, i_token, curr_token_id, i_t, curr_topk_id, tokens);
if(i_t < tokens)
{
int eid = topk_id[i_t * topk + curr_topk_id];
smem_tokens(curr_token_id, eid)++;
}
i += block_size;
}
__syncthreads();
#endif
}
__syncthreads();
// counting
smem_cumsum(0) = 0;
#if 1
#if 0
(void)f_sum;
for(int i_e = tid; i_e < num_experts; i_e += block_size)
{
@@ -727,13 +780,7 @@ struct MoeSortingKernel
{
index_t local_c[8];
index_t cnt = 0;
// TODO: manually unroll. pragma unroll does not work well when we have dependency
// for(int i = 0; i < sub_tokens; i+= 8)
// {
// local_c[0] = smem_tokens(i + lane_group_os, i_e);
// int sum_ = wave_reduce(local_c[0], f_sum);
// cnt += sum_;
// }
for(int i = 0; i < sub_tokens; i += 8 * 8)
{
local_c[0] = smem_tokens(i + 0 * 8 + lane_group_os, i_e);
@@ -755,7 +802,25 @@ struct MoeSortingKernel
// local_c[5],
// local_c[6],
// local_c[7]);
#if 1
cnt +=
(i + 0 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[0], f_sum, number<8>{});
cnt +=
(i + 1 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[1], f_sum, number<8>{});
cnt +=
(i + 2 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[2], f_sum, number<8>{});
cnt +=
(i + 3 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[3], f_sum, number<8>{});
cnt +=
(i + 4 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[4], f_sum, number<8>{});
cnt +=
(i + 5 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[5], f_sum, number<8>{});
cnt +=
(i + 6 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[6], f_sum, number<8>{});
cnt +=
(i + 7 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[7], f_sum, number<8>{});
#else
// TODO: this rely on LDS OOB behavior, too hardware specific
cnt += wave_reduce(local_c[0], f_sum, number<8>{});
cnt += wave_reduce(local_c[1], f_sum, number<8>{});
cnt += wave_reduce(local_c[2], f_sum, number<8>{});
@@ -764,6 +829,7 @@ struct MoeSortingKernel
cnt += wave_reduce(local_c[5], f_sum, number<8>{});
cnt += wave_reduce(local_c[6], f_sum, number<8>{});
cnt += wave_reduce(local_c[7], f_sum, number<8>{});
#endif
}
if(lane_group_os == 0)
smem_cumsum(i_e + 1) = cnt;
@@ -780,7 +846,7 @@ struct MoeSortingKernel
(void)wid;
for(int i = 1; i <= num_experts; ++i)
{
//printf("e:%d -- %d (%d) \n", i - 1, smem_cumsum(i), sub_tokens);
// printf("e:%d -- %d (%d) \n", i - 1, smem_cumsum(i), sub_tokens);
auto current_units = [&]() {
index_t x_ = smem_cumsum(i) + unit_size_mdiv.divisor - 1;
index_t y_ = unit_size_mdiv.div(x_);