Mirror of https://github.com/microsoft/mscclpp.git, synced 2026-05-11 17:00:22 +00:00.
ep/ll: gate SM-count grid bump behind IPC path
On the PortChannel (cross-node) path the extra blocks don't help: the dispatch recv loop strides tokens per-warp-group (not per-SM), and the additional blocks instead add cooperative-grid sync overhead and increase concurrent host-proxy FIFO traffic. Measured cross-node dispatch regressed from 1013us to 3063us when the unconditional grid bump was active. Keep the scaled grid for the IPC path (intra-node), where combine-recv and dispatch token striding scale with sm_id and the 1.2-1.3x speedup reproduces.
This commit is contained in:
@@ -441,18 +441,21 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
   const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
   const auto num_sms_base = cell_div(num_experts, kNumWarpGroups);
-  // LL dispatch/combine are latency-bound at small problem sizes; grid=
-  // cell_div(num_experts,3) leaves most SMs idle. Extra blocks beyond the
-  // expert-owning ones skip the send/count phases (gated by
-  // `responsible_expert_idx < num_experts`) and only participate in the
-  // token-striding recv/combine bodies, where more blocks speed up the
-  // per-token work linearly. Cap at the device SM count because the launch
-  // is cooperative and `__launch_bounds__(num_warps*32, 1)` allows 1
-  // block/SM.
-  int device_num_sms = 0;
-  int cur_dev = 0;
-  cudaGetDevice(&cur_dev);
-  cudaDeviceGetAttribute(&device_num_sms, cudaDevAttrMultiProcessorCount, cur_dev);
+  // LL dispatch/combine are latency-bound at typical problem sizes: for
+  // num_experts=32 the base grid is cell_div(32,3)=11 blocks, i.e. 8% of a
+  // 132-SM H100. The recv-side bodies stride tokens by `sm_id`, so extra
+  // blocks parallelize token work linearly when the transport is cheap.
+  //
+  // Only enabled on the IPC path: on the PortChannel path each extra block
+  // issues more concurrent PUTs into the host proxy FIFO, and the
+  // cg::this_grid().sync() barrier between phases costs more with a larger
+  // grid, which empirically regresses cross-node dispatch.
+  int device_num_sms = num_sms_base;
+  if (use_ipc_path) {
+    int cur_dev = 0;
+    cudaGetDevice(&cur_dev);
+    cudaDeviceGetAttribute(&device_num_sms, cudaDevAttrMultiProcessorCount, cur_dev);
+  }
   const auto num_sms = std::max<int>(num_sms_base,
                                      std::min<int>(device_num_sms, std::max<int>(num_tokens, num_sms_base)));
   EP_HOST_ASSERT(num_topk <= kNumMaxTopK);

@@ -688,11 +691,15 @@ void combine(void* combined_x,
   const auto num_sms_base = cell_div(num_experts, kNumWarpGroups);
   // See the comment in `dispatch()` above: combine-recv's per-token loop
   // strides by `sm_id`, so extra blocks parallelize the weighted reduction
-  // linearly. Scale grid by `num_combined_tokens` and cap at device SMs.
-  int device_num_sms = 0;
-  int cur_dev = 0;
-  cudaGetDevice(&cur_dev);
-  cudaDeviceGetAttribute(&device_num_sms, cudaDevAttrMultiProcessorCount, cur_dev);
+  // linearly on the IPC path. Keep the baseline grid on the PortChannel
+  // path to avoid the cooperative-sync / proxy-FIFO overhead that regressed
+  // cross-node dispatch.
+  int device_num_sms = num_sms_base;
+  if (use_ipc_path) {
+    int cur_dev = 0;
+    cudaGetDevice(&cur_dev);
+    cudaDeviceGetAttribute(&device_num_sms, cudaDevAttrMultiProcessorCount, cur_dev);
+  }
   const auto num_sms = std::max<int>(num_sms_base,
                                      std::min<int>(device_num_sms,
                                                    std::max<int>(num_combined_tokens, num_sms_base)));
Reference in New Issue
Block a user