test/ext/ep: intranode HT bench — cached-mode iter loop

Run one uncached dispatch to capture (rank_prefix_matrix,
channel_prefix_matrix, num_recv_tokens), then time iters in cached mode.
This replaces notify_dispatch + host busy-wait on mapped pinned counters
with the cheap cached_notify_dispatch (one barrier + memcpy + memset),
matching NCCL-EP ep_bench convention.

Cached mode forces num_experts=0 (buffer.cc:807), so topk_idx must be
None in iters; combine still works because recv_topk_weights is optional.

Per-iter dispatch latency drops ~21% (4247→3373µs). Confirms host-side
notify_dispatch overhead is only ~20% of total dispatch time; the
remaining 14.4× send-total asymmetry vs combine is intrinsic (3× recv/
send byte fan-out × 3.8× dispatch-kernel-vs-combine-kernel work).
This commit is contained in:
qinghuazhou
2026-05-12 19:02:31 +00:00
parent 13babbfff2
commit f9f0d0fcb7

View File

@@ -295,6 +295,40 @@ def main():
False,
)
# Run one uncached dispatch to capture the layout (rank/channel prefix
# matrices + num_recv_tokens). Subsequent dispatch iters reuse these in
# cached mode, which skips `notify_dispatch` and its host-side busy-wait on
# mapped pinned counters (`moe_recv_counter`, `moe_recv_expert_counter`).
# This matches NCCL-EP's `ep_bench` convention and isolates the on-GPU
# dispatch kernel cost from one-time setup overhead.
_layout = _dispatch()
_cached_rpm = _layout[5] # rank_prefix_matrix
_cached_cpm = _layout[6] # channel_prefix_matrix
_cached_n = int(_layout[0].size(0)) # num_recv_tokens on this rank
def _dispatch_cached():
# In cached mode `num_experts` is taken as 0, so we must not pass
# topk_idx/topk_weights (those require num_experts > 0). We still get
# send_head/rank_prefix_matrix/channel_prefix_matrix/recv_src_idx out
# of dispatch -- enough to drive combine.
return buf.runtime.intranode_dispatch(
x_b,
None,
None,
None,
None,
is_token_in_rank_b,
None,
_cached_n,
_cached_rpm,
_cached_cpm,
1,
cfg,
None,
False,
False,
)
def _combine(dout):
rx, _rxs, _rti, rtw, _lst, rpm, _cpm, rcpm, rsi, sh, _ev = dout
buf.runtime.intranode_combine(
@@ -310,19 +344,19 @@ def main():
False,
)
# Warmup (full round-trip).
# Warmup (full round-trip) using cached dispatch.
for _ in range(warmup):
_combine(_dispatch())
_combine(_dispatch_cached())
torch.cuda.synchronize()
dist.barrier(group=group)
# Time dispatch alone.
# Time dispatch alone (cached mode -- skips notify_dispatch host wait).
start_ev = torch.cuda.Event(enable_timing=True)
end_ev = torch.cuda.Event(enable_timing=True)
start_ev.record()
dout = None
for _ in range(iters):
dout = _dispatch()
dout = _dispatch_cached()
end_ev.record()
torch.cuda.synchronize()
disp_us = start_ev.elapsed_time(end_ev) * 1e3 / iters