diff --git a/test/python/ext/ep/test_intranode_multirank.py b/test/python/ext/ep/test_intranode_multirank.py index 22dd2da5..a0500f87 100644 --- a/test/python/ext/ep/test_intranode_multirank.py +++ b/test/python/ext/ep/test_intranode_multirank.py @@ -295,6 +295,40 @@ def main(): False, ) + # Run one uncached dispatch to capture the layout (rank/channel prefix + # matrices + num_recv_tokens). Subsequent dispatch iters reuse these in + # cached mode, which skips `notify_dispatch` and its host-side busy-wait on + # mapped pinned counters (`moe_recv_counter`, `moe_recv_expert_counter`). + # This matches NCCL-EP's `ep_bench` convention and isolates the on-GPU + # dispatch kernel cost from one-time setup overhead. + _layout = _dispatch() + _cached_rpm = _layout[5] # rank_prefix_matrix + _cached_cpm = _layout[6] # channel_prefix_matrix + _cached_n = int(_layout[0].size(0)) # num_recv_tokens on this rank + + def _dispatch_cached(): + # In cached mode `num_experts` is taken as 0, so we must not pass + # topk_idx/topk_weights (those require num_experts > 0). We still get + # send_head/rank_prefix_matrix/channel_prefix_matrix/recv_src_idx out + # of dispatch -- enough to drive combine. + return buf.runtime.intranode_dispatch( + x_b, + None, + None, + None, + None, + is_token_in_rank_b, + None, + _cached_n, + _cached_rpm, + _cached_cpm, + 1, + cfg, + None, + False, + False, + ) + def _combine(dout): rx, _rxs, _rti, rtw, _lst, rpm, _cpm, rcpm, rsi, sh, _ev = dout buf.runtime.intranode_combine( @@ -310,19 +344,19 @@ def main(): False, ) - # Warmup (full round-trip). + # Warmup (full round-trip) using cached dispatch. for _ in range(warmup): - _combine(_dispatch()) + _combine(_dispatch_cached()) torch.cuda.synchronize() dist.barrier(group=group) - # Time dispatch alone. + # Time dispatch alone (cached mode -- skips notify_dispatch host wait). start_ev = torch.cuda.Event(enable_timing=True) end_ev = torch.cuda.Event(enable_timing=True) start_ev.record() dout = None for _ in range(iters): - dout = _dispatch() + dout = _dispatch_cached() end_ev.record() torch.cuda.synchronize() disp_us = start_ev.elapsed_time(end_ev) * 1e3 / iters