From 6d0f99f0845b6e5d3c27ea0e0502a82104c75940 Mon Sep 17 00:00:00 2001 From: Qinghua Zhou Date: Fri, 24 Apr 2026 00:02:33 +0000 Subject: [PATCH] ep/ll: gate SM-count grid bump behind IPC path On the PortChannel (cross-node) path the extra blocks don't help: the dispatch recv loop strides tokens per-warp-group (not per-SM), and the additional blocks instead add cooperative-grid sync overhead and increase concurrent host-proxy FIFO traffic. Measured cross-node dispatch regressed from 1013us to 3063us when the unconditional grid bump was active. Keep the scaled grid for the IPC path (intra-node), where combine-recv and dispatch token striding scale with sm_id and the 1.2-1.3x speedup reproduces. --- src/ext/ep/kernels/internode_ll.cu | 41 +++++++++++++++++------------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/src/ext/ep/kernels/internode_ll.cu b/src/ext/ep/kernels/internode_ll.cu index cf51d033..7598269e 100644 --- a/src/ext/ep/kernels/internode_ll.cu +++ b/src/ext/ep/kernels/internode_ll.cu @@ -441,18 +441,21 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales, const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup; const auto num_sms_base = cell_div(num_experts, kNumWarpGroups); - // LL dispatch/combine are latency-bound at small problem sizes; grid= - // cell_div(num_experts,3) leaves most SMs idle. Extra blocks beyond the - // expert-owning ones skip the send/count phases (gated by - // `responsible_expert_idx < num_experts`) and only participate in the - // token-striding recv/combine bodies, where more blocks speed up the - // per-token work linearly. Cap at the device SM count because the launch - // is cooperative and `__launch_bounds__(num_warps*32, 1)` allows 1 - // block/SM. - int device_num_sms = 0; - int cur_dev = 0; - cudaGetDevice(&cur_dev); - cudaDeviceGetAttribute(&device_num_sms, cudaDevAttrMultiProcessorCount, cur_dev); + // LL dispatch/combine are latency-bound at typical problem sizes: for + // num_experts=32 the base grid is cell_div(32,3)=11 blocks, i.e. 8% of a + // 132-SM H100. The recv-side bodies stride tokens by `sm_id`, so extra + // blocks parallelize token work linearly when the transport is cheap. + // + // Only enabled on the IPC path: on the PortChannel path each extra block + // issues more concurrent PUTs into the host proxy FIFO, and the + // cg::this_grid().sync() barrier between phases costs more with a larger + // grid, which empirically regresses cross-node dispatch. + int device_num_sms = num_sms_base; + if (use_ipc_path) { + int cur_dev = 0; + cudaGetDevice(&cur_dev); + cudaDeviceGetAttribute(&device_num_sms, cudaDevAttrMultiProcessorCount, cur_dev); + } const auto num_sms = std::max(num_sms_base, std::min(device_num_sms, std::max(num_tokens, num_sms_base))); EP_HOST_ASSERT(num_topk <= kNumMaxTopK); @@ -688,11 +691,15 @@ void combine(void* combined_x, const auto num_sms_base = cell_div(num_experts, kNumWarpGroups); // See the comment in `dispatch()` above: combine-recv's per-token loop // strides by `sm_id`, so extra blocks parallelize the weighted reduction - // linearly. Scale grid by `num_combined_tokens` and cap at device SMs. - int device_num_sms = 0; - int cur_dev = 0; - cudaGetDevice(&cur_dev); - cudaDeviceGetAttribute(&device_num_sms, cudaDevAttrMultiProcessorCount, cur_dev); + // linearly on the IPC path. Keep the baseline grid on the PortChannel + // path to avoid the cooperative-sync / proxy-FIFO overhead that regressed + // cross-node dispatch. + int device_num_sms = num_sms_base; + if (use_ipc_path) { + int cur_dev = 0; + cudaGetDevice(&cur_dev); + cudaDeviceGetAttribute(&device_num_sms, cudaDevAttrMultiProcessorCount, cur_dev); + } const auto num_sms = std::max(num_sms_base, std::min(device_num_sms, std::max(num_combined_tokens, num_sms_base)));