From 3cf7343e085cecde790aa52d957d50b35ed7adae Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Mon, 29 Sep 2025 17:12:15 +0000 Subject: [PATCH] Merge commit '2e9428eb63be091b109537e082aa7f0fc05a634d' into develop --- .../ops/fused_moe/kernel/moe_sorting_kernel.hpp | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 28416ec538..42e2fad236 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -1574,6 +1574,7 @@ struct MoeSortingMultiPhaseKernel_P0 void* p_expert_mesh; // [expert, tokens] index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens // used for ws/LDS calculation + index_t num_experts; index_t mesh_stride; // mesh_stride for p_expert_mesh mdiv topk_mdiv; }; @@ -1597,6 +1598,7 @@ struct MoeSortingMultiPhaseKernel_P0 k.p_local_tokens = h.p_local_tokens; k.p_expert_mesh = h.p_ws; k.tokens = h.tokens; + k.num_experts = h.num_experts; k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens); k.topk_mdiv = mdiv{static_cast(h.topk)}; return k; @@ -1655,14 +1657,18 @@ struct MoeSortingMultiPhaseKernel_P0 IndexType eid = x[j.value]; // ext_vector_type must use int to [] uint32_t curr_token_id, curr_topk_id; kargs.topk_mdiv.divmod(i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id); - if constexpr(Problem::LocalToken) + if(eid < kargs.num_experts) { - if(static_cast(curr_token_id) < tokens) + if constexpr(Problem::LocalToken) + { + if(static_cast(curr_token_id) < tokens) + p_expert_mesh[eid * mesh_stride + curr_token_id] = + (curr_topk_id + 1) & 0xffff; + } + else p_expert_mesh[eid * mesh_stride + curr_token_id] = (curr_topk_id + 1) & 0xffff; } - else - p_expert_mesh[eid * mesh_stride + curr_token_id] = (curr_topk_id + 1) & 0xffff; }); } }