ep/ll: use 1 expert per SM with 32 warps per block

NCCL-EP's LL dispatch/combine kernel uses (numWarpGroups=1,
numWarpsPerGroup=32) when num_experts <= device_num_sms, giving each
SM ownership of a single expert and 32 warps to cooperate on its
recv-side per-(expert, src_rank) work. We were using (3, 10) — 3
experts per SM, 10 warps per (expert, rank) pair — which left a
significant amount of recv-side parallelism on the table because each
warp had to walk ~3x more tokens sequentially.

Switching to (1, 32) for both dispatch and combine matches NCCL-EP's
structure for typical EP sizes (num_experts in {32, 64, 256}) where
num_experts <= 132 SMs.

The static_assert kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup
still holds (10 <= 32), and the wider block also lets the staging loop
process the hidden dim with one int4 per thread (hidden_bf16_int4 = 896
fits easily in the 992 working threads).
Qinghua Zhou
2026-04-25 00:25:01 +00:00
parent 1600074f09
commit 7d6efee18b

@@ -435,8 +435,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
               mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
               bool use_ipc_path) {
   constexpr int kNumMaxTopK = 9;
-  constexpr int kNumWarpsPerGroup = 10;
-  constexpr int kNumWarpGroups = 3;
+  constexpr int kNumWarpsPerGroup = 32;
+  constexpr int kNumWarpGroups = 1;
   EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections");
   const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
@@ -683,8 +683,8 @@ void combine(void* combined_x,
              void* const* peer_rdma_bases,
              mscclpp::MemoryChannelDeviceHandle* memory_channel_handles,
              bool use_ipc_path) {
-  constexpr int kNumWarpsPerGroup = 10;
-  constexpr int kNumWarpGroups = 3;
+  constexpr int kNumWarpsPerGroup = 32;
+  constexpr int kNumWarpGroups = 1;
   constexpr int kNumMaxTopk = 9;
   const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;