ep/ll: gate SM-count grid bump behind IPC path

On the PortChannel (cross-node) path the extra blocks don't help: the
dispatch recv loop strides tokens per-warp-group (not per-SM), and the
additional blocks instead add cooperative-grid sync overhead and
increase concurrent host-proxy FIFO traffic. Measured cross-node
dispatch regressed from 1013us to 3063us when the unconditional grid
bump was active.

Keep the scaled grid on the IPC (intra-node) path, where the combine-recv
and dispatch recv loops stride tokens by sm_id and the 1.2-1.3x speedup
reproduces.
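
For intuition, a minimal sketch of the two recv-loop shapes. This is
illustrative only, not the real kernels: the kernel names, the
float-copy body standing in for the per-token work, and the exact
striding variables are placeholders.

    // IPC path: tokens stride by SM, so every extra block takes a
    // share of the per-token work.
    __global__ void ipc_recv_like(const float* in, float* out, int num_tokens) {
        for (int token = blockIdx.x; token < num_tokens; token += gridDim.x)
            out[token] = in[token];
    }

    // PortChannel path: tokens stride per warp group inside the
    // expert-owning blocks; blocks beyond those have no tokens to
    // claim and only add grid-sync and proxy-FIFO pressure.
    template <int kNumWarpGroups, int kNumWarpsPerGroup>
    __global__ void port_recv_like(const float* in, float* out, int num_tokens) {
        const int warp_group_id = threadIdx.x / (kNumWarpsPerGroup * 32);
        for (int token = warp_group_id; token < num_tokens; token += kNumWarpGroups)
            out[token] = in[token];
    }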
commit 6d0f99f084
parent 85316b1863
Author: Qinghua Zhou
Date: 2026-04-24 00:02:33 +00:00


@@ -441,18 +441,21 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
     const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
     const auto num_sms_base = cell_div(num_experts, kNumWarpGroups);
-    // LL dispatch/combine are latency-bound at small problem sizes; grid=
-    // cell_div(num_experts,3) leaves most SMs idle. Extra blocks beyond the
-    // expert-owning ones skip the send/count phases (gated by
-    // `responsible_expert_idx < num_experts`) and only participate in the
-    // token-striding recv/combine bodies, where more blocks speed up the
-    // per-token work linearly. Cap at the device SM count because the launch
-    // is cooperative and `__launch_bounds__(num_warps*32, 1)` allows 1
-    // block/SM.
-    int device_num_sms = 0;
-    int cur_dev = 0;
-    cudaGetDevice(&cur_dev);
-    cudaDeviceGetAttribute(&device_num_sms, cudaDevAttrMultiProcessorCount, cur_dev);
+    // LL dispatch/combine are latency-bound at typical problem sizes: for
+    // num_experts=32 the base grid is cell_div(32,3)=11 blocks, i.e. 8% of a
+    // 132-SM H100. The recv-side bodies stride tokens by `sm_id`, so extra
+    // blocks parallelize token work linearly when the transport is cheap.
+    //
+    // Only enabled on the IPC path: on the PortChannel path each extra block
+    // issues more concurrent PUTs into the host proxy FIFO, and the
+    // cg::this_grid().sync() barrier between phases costs more with a larger
+    // grid, which empirically regresses cross-node dispatch.
+    int device_num_sms = num_sms_base;
+    if (use_ipc_path) {
+        int cur_dev = 0;
+        cudaGetDevice(&cur_dev);
+        cudaDeviceGetAttribute(&device_num_sms, cudaDevAttrMultiProcessorCount, cur_dev);
+    }
     const auto num_sms = std::max<int>(num_sms_base,
         std::min<int>(device_num_sms, std::max<int>(num_tokens, num_sms_base)));
     EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
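
To make the clamp concrete, using the numbers from the comment above
(num_experts=32, kNumWarpGroups=3, so num_sms_base = cell_div(32,3) = 11,
on a 132-SM H100) and an assumed num_tokens=256:

    IPC path:         num_sms = max(11, min(132, max(256, 11))) = 132
    PortChannel path: num_sms = max(11, min(11,  max(256, 11))) = 11
    IPC, tiny batch (num_tokens=8):
                      num_sms = max(11, min(132, max(8, 11)))   = 11

So the bump only ever widens the grid on the IPC path, and only when
there are enough tokens to occupy the extra blocks.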
@@ -688,11 +691,15 @@ void combine(void* combined_x,
     const auto num_sms_base = cell_div(num_experts, kNumWarpGroups);
     // See the comment in `dispatch()` above: combine-recv's per-token loop
     // strides by `sm_id`, so extra blocks parallelize the weighted reduction
-    // linearly. Scale grid by `num_combined_tokens` and cap at device SMs.
-    int device_num_sms = 0;
-    int cur_dev = 0;
-    cudaGetDevice(&cur_dev);
-    cudaDeviceGetAttribute(&device_num_sms, cudaDevAttrMultiProcessorCount, cur_dev);
+    // linearly on the IPC path. Keep the baseline grid on the PortChannel
+    // path to avoid the cooperative-sync / proxy-FIFO overhead that regressed
+    // cross-node dispatch.
+    int device_num_sms = num_sms_base;
+    if (use_ipc_path) {
+        int cur_dev = 0;
+        cudaGetDevice(&cur_dev);
+        cudaDeviceGetAttribute(&device_num_sms, cudaDevAttrMultiProcessorCount, cur_dev);
+    }
     const auto num_sms = std::max<int>(num_sms_base,
         std::min<int>(device_num_sms,
             std::max<int>(num_combined_tokens, num_sms_base)));
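
The cap at device_num_sms is not just a heuristic: these kernels launch
cooperatively, and cg::this_grid().sync() requires every block to be
co-resident. A sketch of that constraint, assuming the
`__launch_bounds__(num_warps*32, 1)` attribute mentioned in the old
comment (`combine_kernel` and its argument list are placeholders):

    int max_blocks_per_sm = 0;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_blocks_per_sm, combine_kernel, num_warps * 32, /*dynamic smem*/ 0);
    // __launch_bounds__(num_warps * 32, 1) pins this to 1 block/SM, so
    // the largest legal cooperative grid is device_num_sms blocks;
    // anything bigger fails with cudaErrorCooperativeLaunchTooLarge.
    void* args[] = {/* kernel arguments */};
    cudaLaunchCooperativeKernel((void*)combine_kernel, dim3(num_sms),
                                dim3(num_warps * 32), args,
                                /*sharedMem=*/0, /*stream=*/nullptr);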