mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
hot fix check eid range (#2924)
* hot fix check eid range * fix clang format --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
This commit is contained in:
@@ -1574,6 +1574,7 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens
|
||||
// used for ws/LDS calculation
|
||||
index_t num_experts;
|
||||
index_t mesh_stride; // mesh_stride for p_expert_mesh
|
||||
mdiv topk_mdiv;
|
||||
};
|
||||
@@ -1597,6 +1598,7 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.tokens = h.tokens;
|
||||
k.num_experts = h.num_experts;
|
||||
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
return k;
|
||||
@@ -1655,14 +1657,18 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
IndexType eid = x[j.value]; // ext_vector_type must use int to []
|
||||
uint32_t curr_token_id, curr_topk_id;
|
||||
kargs.topk_mdiv.divmod(i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id);
|
||||
if constexpr(Problem::LocalToken)
|
||||
if(eid < kargs.num_experts)
|
||||
{
|
||||
if(static_cast<index_t>(curr_token_id) < tokens)
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
if(static_cast<index_t>(curr_token_id) < tokens)
|
||||
p_expert_mesh[eid * mesh_stride + curr_token_id] =
|
||||
(curr_topk_id + 1) & 0xffff;
|
||||
}
|
||||
else
|
||||
p_expert_mesh[eid * mesh_stride + curr_token_id] =
|
||||
(curr_topk_id + 1) & 0xffff;
|
||||
}
|
||||
else
|
||||
p_expert_mesh[eid * mesh_stride + curr_token_id] = (curr_topk_id + 1) & 0xffff;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user