hot fix check eid range (#2924)

* hot fix check eid range

* fix clang format

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
This commit is contained in:
carlushuang
2025-09-30 00:38:38 +08:00
committed by GitHub
parent 2b684f0a7d
commit 2e9428eb63

View File

@@ -1574,6 +1574,7 @@ struct MoeSortingMultiPhaseKernel_P0
void* p_expert_mesh; // [expert, tokens]
index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens
// used for ws/LDS calculation
index_t num_experts;
index_t mesh_stride; // mesh_stride for p_expert_mesh
mdiv topk_mdiv;
};
@@ -1597,6 +1598,7 @@ struct MoeSortingMultiPhaseKernel_P0
k.p_local_tokens = h.p_local_tokens;
k.p_expert_mesh = h.p_ws;
k.tokens = h.tokens;
k.num_experts = h.num_experts;
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
return k;
@@ -1655,14 +1657,18 @@ struct MoeSortingMultiPhaseKernel_P0
IndexType eid = x[j.value]; // ext_vector_type must use int to []
uint32_t curr_token_id, curr_topk_id;
kargs.topk_mdiv.divmod(i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id);
if constexpr(Problem::LocalToken)
if(eid < kargs.num_experts)
{
if(static_cast<index_t>(curr_token_id) < tokens)
if constexpr(Problem::LocalToken)
{
if(static_cast<index_t>(curr_token_id) < tokens)
p_expert_mesh[eid * mesh_stride + curr_token_id] =
(curr_topk_id + 1) & 0xffff;
}
else
p_expert_mesh[eid * mesh_stride + curr_token_id] =
(curr_topk_id + 1) & 0xffff;
}
else
p_expert_mesh[eid * mesh_stride + curr_token_id] = (curr_topk_id + 1) & 0xffff;
});
}
}