Merge commit '2e9428eb63be091b109537e082aa7f0fc05a634d' into develop

This commit is contained in:
assistant-librarian[bot]
2025-09-29 17:12:15 +00:00
parent 2e7d600076
commit 3cf7343e08

View File

@@ -1574,6 +1574,7 @@ struct MoeSortingMultiPhaseKernel_P0
void* p_expert_mesh; // [expert, tokens]
index_t tokens; // if p_local_tokens is not nullptr, this indicates the max possible tokens
// used for ws/LDS calculation
index_t num_experts;
index_t mesh_stride; // mesh_stride for p_expert_mesh
mdiv topk_mdiv;
};
@@ -1597,6 +1598,7 @@ struct MoeSortingMultiPhaseKernel_P0
k.p_local_tokens = h.p_local_tokens;
k.p_expert_mesh = h.p_ws;
k.tokens = h.tokens;
k.num_experts = h.num_experts;
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
return k;
@@ -1655,14 +1657,18 @@ struct MoeSortingMultiPhaseKernel_P0
IndexType eid = x[j.value]; // ext_vector_type must use int to []
uint32_t curr_token_id, curr_topk_id;
kargs.topk_mdiv.divmod(i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id);
if constexpr(Problem::LocalToken)
if(eid < kargs.num_experts)
{
if(static_cast<index_t>(curr_token_id) < tokens)
if constexpr(Problem::LocalToken)
{
if(static_cast<index_t>(curr_token_id) < tokens)
p_expert_mesh[eid * mesh_stride + curr_token_id] =
(curr_topk_id + 1) & 0xffff;
}
else
p_expert_mesh[eid * mesh_stride + curr_token_id] =
(curr_topk_id + 1) & 0xffff;
}
else
p_expert_mesh[eid * mesh_stride + curr_token_id] = (curr_topk_id + 1) & 0xffff;
});
}
}