mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
fix bug for race condition
This commit is contained in:
@@ -19,5 +19,5 @@ $EXE -t=99 -e=2 -k=1
|
||||
$EXE -t=333 -e=99 -k=13
|
||||
$EXE -t=11 -e=256 -k=5
|
||||
$EXE -t=64 -e=455 -k=8
|
||||
# $EXE -t=777 -e=802 -k=99 # has bug???
|
||||
$EXE -t=777 -e=802 -k=99 # has bug???
|
||||
$EXE -t=128 -e=32 -k=5 -moe_buf_size=262144
|
||||
|
||||
@@ -178,7 +178,7 @@ struct MoeSortingKernel
|
||||
return r_for_sub_token + cumsum_bufs;
|
||||
}();
|
||||
|
||||
printf("r:%d, c:%d\n", smem_rows, smem_cols);
|
||||
// printf("r:%d, c:%d\n", smem_rows, smem_cols);
|
||||
|
||||
return ck_tile::make_tuple(smem_rows, smem_cols);
|
||||
}
|
||||
@@ -668,26 +668,79 @@ struct MoeSortingKernel
|
||||
for(int i_token = 0; i_token < tokens; i_token += sub_tokens)
|
||||
// int i_token = 0;
|
||||
{
|
||||
#if 1
|
||||
__syncthreads();
|
||||
// #pragma unroll 8
|
||||
for(int i = tid; i < (sub_tokens * topk); i += block_size)
|
||||
{
|
||||
uint32_t curr_token_id, curr_topk_id;
|
||||
topk_mdiv.divmod(i, curr_token_id, curr_topk_id);
|
||||
int i_t = i_token + curr_token_id;
|
||||
// printf("--- tid:%d, i_token:%d, curr_token_id:%d, curr_topk_id:%d, tokens:%d\n",
|
||||
// tid, i_token, curr_token_id,curr_topk_id, tokens);
|
||||
|
||||
if(i_t < tokens)
|
||||
{
|
||||
|
||||
int eid = topk_id[i_t * topk + curr_topk_id];
|
||||
|
||||
// printf("eid:%d, [%d] tid:%d, (i_token:%d, curr_token_id:%d)i_t:%d,
|
||||
// curr_topk_id:%d, tokens:%d\n",
|
||||
// eid, i, tid, i_token, curr_token_id, i_t, curr_topk_id, tokens);
|
||||
smem_tokens(curr_token_id, eid)++;
|
||||
}
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
//
|
||||
}
|
||||
__syncthreads();
|
||||
// if(tid == 0) {
|
||||
// int e0 = smem_tokens(0, 0);
|
||||
// int e1 = smem_tokens(1, 0);
|
||||
// int e2 = smem_tokens(2, 0);
|
||||
// int e3 = smem_tokens(3, 0);
|
||||
// int e4 = smem_tokens(4, 0);
|
||||
// int e5 = smem_tokens(5, 0);
|
||||
// int e6 = smem_tokens(6, 0);
|
||||
// int e7 = smem_tokens(7, 0);
|
||||
// printf("xxx eid:%d i_token:%d, cnt:%d,%d,%d,%d,%d,%d,%d,%d(%d)\n", 0, i_token,
|
||||
// e0,
|
||||
// e1,
|
||||
// e2,
|
||||
// e3,
|
||||
// e4,
|
||||
// e5,
|
||||
// e6,
|
||||
// e7,
|
||||
// e0+e1+e2+e3+e4+e5+e6+e7
|
||||
// );
|
||||
// }
|
||||
|
||||
#else
|
||||
int i = tid;
|
||||
while(true)
|
||||
{
|
||||
__syncthreads();
|
||||
if(i >= (sub_tokens * topk))
|
||||
break;
|
||||
uint32_t curr_token_id, curr_topk_id;
|
||||
topk_mdiv.divmod(i, curr_token_id, curr_topk_id);
|
||||
int i_t = i_token + curr_token_id;
|
||||
// printf("[%d] tid:%d, (i_token:%d, curr_token_id:%d)i_t:%d, curr_topk_id:%d,
|
||||
// tokens:%d\n",
|
||||
// i, tid, i_token, curr_token_id, i_t, curr_topk_id, tokens);
|
||||
if(i_t < tokens)
|
||||
{
|
||||
int eid = topk_id[i_t * topk + curr_topk_id];
|
||||
smem_tokens(curr_token_id, eid)++;
|
||||
}
|
||||
|
||||
i += block_size;
|
||||
}
|
||||
__syncthreads();
|
||||
#endif
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// counting
|
||||
smem_cumsum(0) = 0;
|
||||
#if 1
|
||||
#if 0
|
||||
(void)f_sum;
|
||||
for(int i_e = tid; i_e < num_experts; i_e += block_size)
|
||||
{
|
||||
@@ -727,13 +780,7 @@ struct MoeSortingKernel
|
||||
{
|
||||
index_t local_c[8];
|
||||
index_t cnt = 0;
|
||||
// TODO: manually unroll. pragma unroll does not work well when we have dependency
|
||||
// for(int i = 0; i < sub_tokens; i+= 8)
|
||||
// {
|
||||
// local_c[0] = smem_tokens(i + lane_group_os, i_e);
|
||||
// int sum_ = wave_reduce(local_c[0], f_sum);
|
||||
// cnt += sum_;
|
||||
// }
|
||||
|
||||
for(int i = 0; i < sub_tokens; i += 8 * 8)
|
||||
{
|
||||
local_c[0] = smem_tokens(i + 0 * 8 + lane_group_os, i_e);
|
||||
@@ -755,7 +802,25 @@ struct MoeSortingKernel
|
||||
// local_c[5],
|
||||
// local_c[6],
|
||||
// local_c[7]);
|
||||
|
||||
#if 1
|
||||
cnt +=
|
||||
(i + 0 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[0], f_sum, number<8>{});
|
||||
cnt +=
|
||||
(i + 1 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[1], f_sum, number<8>{});
|
||||
cnt +=
|
||||
(i + 2 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[2], f_sum, number<8>{});
|
||||
cnt +=
|
||||
(i + 3 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[3], f_sum, number<8>{});
|
||||
cnt +=
|
||||
(i + 4 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[4], f_sum, number<8>{});
|
||||
cnt +=
|
||||
(i + 5 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[5], f_sum, number<8>{});
|
||||
cnt +=
|
||||
(i + 6 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[6], f_sum, number<8>{});
|
||||
cnt +=
|
||||
(i + 7 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[7], f_sum, number<8>{});
|
||||
#else
|
||||
// TODO: this rely on LDS OOB behavior, too hardware specific
|
||||
cnt += wave_reduce(local_c[0], f_sum, number<8>{});
|
||||
cnt += wave_reduce(local_c[1], f_sum, number<8>{});
|
||||
cnt += wave_reduce(local_c[2], f_sum, number<8>{});
|
||||
@@ -764,6 +829,7 @@ struct MoeSortingKernel
|
||||
cnt += wave_reduce(local_c[5], f_sum, number<8>{});
|
||||
cnt += wave_reduce(local_c[6], f_sum, number<8>{});
|
||||
cnt += wave_reduce(local_c[7], f_sum, number<8>{});
|
||||
#endif
|
||||
}
|
||||
if(lane_group_os == 0)
|
||||
smem_cumsum(i_e + 1) = cnt;
|
||||
@@ -780,7 +846,7 @@ struct MoeSortingKernel
|
||||
(void)wid;
|
||||
for(int i = 1; i <= num_experts; ++i)
|
||||
{
|
||||
//printf("e:%d -- %d (%d) \n", i - 1, smem_cumsum(i), sub_tokens);
|
||||
// printf("e:%d -- %d (%d) \n", i - 1, smem_cumsum(i), sub_tokens);
|
||||
auto current_units = [&]() {
|
||||
index_t x_ = smem_cumsum(i) + unit_size_mdiv.divisor - 1;
|
||||
index_t y_ = unit_size_mdiv.div(x_);
|
||||
|
||||
Reference in New Issue
Block a user