mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-25 07:14:40 +00:00
test/ext/ep: intranode HT bench — cached-mode iter loop
Run one uncached dispatch to capture (rank_prefix_matrix, channel_prefix_matrix, num_recv_tokens), then time iters in cached mode. This replaces notify_dispatch + host busy-wait on mapped pinned counters with the cheap cached_notify_dispatch (one barrier + memcpy + memset), matching NCCL-EP ep_bench convention. Cached mode forces num_experts=0 (buffer.cc:807), so topk_idx must be None in iters; combine still works because recv_topk_weights is optional. Per-iter dispatch latency drops ~21% (4247→3373µs). Confirms host-side notify_dispatch overhead is only ~20% of total dispatch time; the remaining 14.4× send-total asymmetry vs combine is intrinsic (3× recv/ send byte fan-out × 3.8× dispatch-kernel-vs-combine-kernel work).
This commit is contained in:
@@ -295,6 +295,40 @@ def main():
|
||||
False,
|
||||
)
|
||||
|
||||
# Run one uncached dispatch to capture the layout (rank/channel prefix
|
||||
# matrices + num_recv_tokens). Subsequent dispatch iters reuse these in
|
||||
# cached mode, which skips `notify_dispatch` and its host-side busy-wait on
|
||||
# mapped pinned counters (`moe_recv_counter`, `moe_recv_expert_counter`).
|
||||
# This matches NCCL-EP's `ep_bench` convention and isolates the on-GPU
|
||||
# dispatch kernel cost from one-time setup overhead.
|
||||
_layout = _dispatch()
|
||||
_cached_rpm = _layout[5] # rank_prefix_matrix
|
||||
_cached_cpm = _layout[6] # channel_prefix_matrix
|
||||
_cached_n = int(_layout[0].size(0)) # num_recv_tokens on this rank
|
||||
|
||||
def _dispatch_cached():
|
||||
# In cached mode `num_experts` is taken as 0, so we must not pass
|
||||
# topk_idx/topk_weights (those require num_experts > 0). We still get
|
||||
# send_head/rank_prefix_matrix/channel_prefix_matrix/recv_src_idx out
|
||||
# of dispatch -- enough to drive combine.
|
||||
return buf.runtime.intranode_dispatch(
|
||||
x_b,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
is_token_in_rank_b,
|
||||
None,
|
||||
_cached_n,
|
||||
_cached_rpm,
|
||||
_cached_cpm,
|
||||
1,
|
||||
cfg,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
|
||||
def _combine(dout):
|
||||
rx, _rxs, _rti, rtw, _lst, rpm, _cpm, rcpm, rsi, sh, _ev = dout
|
||||
buf.runtime.intranode_combine(
|
||||
@@ -310,19 +344,19 @@ def main():
|
||||
False,
|
||||
)
|
||||
|
||||
# Warmup (full round-trip).
|
||||
# Warmup (full round-trip) using cached dispatch.
|
||||
for _ in range(warmup):
|
||||
_combine(_dispatch())
|
||||
_combine(_dispatch_cached())
|
||||
torch.cuda.synchronize()
|
||||
dist.barrier(group=group)
|
||||
|
||||
# Time dispatch alone.
|
||||
# Time dispatch alone (cached mode -- skips notify_dispatch host wait).
|
||||
start_ev = torch.cuda.Event(enable_timing=True)
|
||||
end_ev = torch.cuda.Event(enable_timing=True)
|
||||
start_ev.record()
|
||||
dout = None
|
||||
for _ in range(iters):
|
||||
dout = _dispatch()
|
||||
dout = _dispatch_cached()
|
||||
end_ev.record()
|
||||
torch.cuda.synchronize()
|
||||
disp_us = start_ev.elapsed_time(end_ev) * 1e3 / iters
|
||||
|
||||
Reference in New Issue
Block a user