From f9f0d0fcb7dbea0795b9cb99729948fa71fbe3aa Mon Sep 17 00:00:00 2001 From: qinghuazhou Date: Tue, 12 May 2026 19:02:31 +0000 Subject: [PATCH] =?UTF-8?q?test/ext/ep:=20intranode=20HT=20bench=20?= =?UTF-8?q?=E2=80=94=20cached-mode=20iter=20loop?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Run one uncached dispatch to capture (rank_prefix_matrix, channel_prefix_matrix, num_recv_tokens), then time iters in cached mode. This replaces notify_dispatch + host busy-wait on mapped pinned counters with the cheap cached_notify_dispatch (one barrier + memcpy + memset), matching NCCL-EP ep_bench convention. Cached mode forces num_experts=0 (buffer.cc:807), so topk_idx must be None in iters; combine still works because recv_topk_weights is optional. Per-iter dispatch latency drops ~21% (4247→3373µs). Confirms host-side notify_dispatch overhead is only ~20% of total dispatch time; the remaining 14.4× send-total asymmetry vs combine is intrinsic (3× recv/ send byte fan-out × 3.8× dispatch-kernel-vs-combine-kernel work). --- .../python/ext/ep/test_intranode_multirank.py | 42 +++++++++++++++++-- 1 file changed, 38 insertions(+), 4 deletions(-) diff --git a/test/python/ext/ep/test_intranode_multirank.py b/test/python/ext/ep/test_intranode_multirank.py index 22dd2da5..a0500f87 100644 --- a/test/python/ext/ep/test_intranode_multirank.py +++ b/test/python/ext/ep/test_intranode_multirank.py @@ -295,6 +295,40 @@ def main(): False, ) + # Run one uncached dispatch to capture the layout (rank/channel prefix + # matrices + num_recv_tokens). Subsequent dispatch iters reuse these in + # cached mode, which skips `notify_dispatch` and its host-side busy-wait on + # mapped pinned counters (`moe_recv_counter`, `moe_recv_expert_counter`). + # This matches NCCL-EP's `ep_bench` convention and isolates the on-GPU + # dispatch kernel cost from one-time setup overhead. + _layout = _dispatch() + _cached_rpm = _layout[5] # rank_prefix_matrix + _cached_cpm = _layout[6] # channel_prefix_matrix + _cached_n = int(_layout[0].size(0)) # num_recv_tokens on this rank + + def _dispatch_cached(): + # In cached mode `num_experts` is taken as 0, so we must not pass + # topk_idx/topk_weights (those require num_experts > 0). We still get + # send_head/rank_prefix_matrix/channel_prefix_matrix/recv_src_idx out + # of dispatch -- enough to drive combine. + return buf.runtime.intranode_dispatch( + x_b, + None, + None, + None, + None, + is_token_in_rank_b, + None, + _cached_n, + _cached_rpm, + _cached_cpm, + 1, + cfg, + None, + False, + False, + ) + def _combine(dout): rx, _rxs, _rti, rtw, _lst, rpm, _cpm, rcpm, rsi, sh, _ev = dout buf.runtime.intranode_combine( @@ -310,19 +344,19 @@ def main(): False, ) - # Warmup (full round-trip). + # Warmup (full round-trip) using cached dispatch. for _ in range(warmup): - _combine(_dispatch()) + _combine(_dispatch_cached()) torch.cuda.synchronize() dist.barrier(group=group) - # Time dispatch alone. + # Time dispatch alone (cached mode -- skips notify_dispatch host wait). start_ev = torch.cuda.Event(enable_timing=True) end_ev = torch.cuda.Event(enable_timing=True) start_ev.record() dout = None for _ in range(iters): - dout = _dispatch() + dout = _dispatch_cached() end_ev.record() torch.cuda.synchronize() disp_us = start_ev.elapsed_time(end_ev) * 1e3 / iters