From 7d6efee18bdd1af98c015dd144d237b21f4fe375 Mon Sep 17 00:00:00 2001 From: Qinghua Zhou Date: Sat, 25 Apr 2026 00:25:01 +0000 Subject: [PATCH] ep/ll: use 1 expert per SM with 32 warps per block MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit NCCL-EP's LL dispatch/combine kernel uses (numWarpGroups=1, numWarpsPerGroup=32) when num_experts <= device_num_sms, giving each SM ownership of a single expert and 32 warps to cooperate on its recv-side per-(expert, src_rank) work. We were using (3, 10) — 3 experts per SM, 10 warps per (expert, rank) pair — which left a significant amount of recv-side parallelism on the table because each warp had to walk ~3x more tokens sequentially. Switching to (1, 32) for both dispatch and combine matches NCCL-EP's structure for typical EP sizes (num_experts in {32, 64, 256}) where num_experts <= 132 SMs. The static_assert kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup still holds (9 <= 32) and the wider block also lets the staging loop process the hidden-dim with one int4 per thread (hidden_bf16_int4=896 fits easily in 992 working threads). --- src/ext/ep/kernels/internode_ll.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ext/ep/kernels/internode_ll.cu b/src/ext/ep/kernels/internode_ll.cu index 7598269e..bd561f4b 100644 --- a/src/ext/ep/kernels/internode_ll.cu +++ b/src/ext/ep/kernels/internode_ll.cu @@ -435,8 +435,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales, mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path) { constexpr int kNumMaxTopK = 9; - constexpr int kNumWarpsPerGroup = 10; - constexpr int kNumWarpGroups = 3; + constexpr int kNumWarpsPerGroup = 32; + constexpr int kNumWarpGroups = 1; EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections"); const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup; @@ -683,8 +683,8 @@ void combine(void* combined_x, void* const* peer_rdma_bases, mscclpp::MemoryChannelDeviceHandle* memory_channel_handles, bool use_ipc_path) { - constexpr int kNumWarpsPerGroup = 10; - constexpr int kNumWarpGroups = 3; + constexpr int kNumWarpsPerGroup = 32; + constexpr int kNumWarpGroups = 1; constexpr int kNumMaxTopk = 9; const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;