diff --git a/test/python/ext/ep/test_internode_multirank.py b/test/python/ext/ep/test_internode_multirank.py
index 94f64681..7c3aafb7 100644
--- a/test/python/ext/ep/test_internode_multirank.py
+++ b/test/python/ext/ep/test_internode_multirank.py
@@ -17,8 +17,9 @@ Launch on each node with (example: 2 nodes x 8 GPUs = 16 ranks):
 
 Round-trip dispatch + combine using internode HT kernels across nodes.
 Set ``MSCCLPP_EP_BENCH=1`` to also run a post-correctness benchmark pass
-(CUDA-event timed, reports per-iter latency and aggregate effective
-bandwidth from rank 0). Override iteration counts with
+that times dispatch and combine **separately** with CUDA events. Reports
+per-phase latency (max across ranks) plus aggregate effective bandwidth
+(sum across ranks). Override iteration counts with
 ``MSCCLPP_EP_BENCH_WARMUP`` / ``MSCCLPP_EP_BENCH_ITERS`` and the bench
 problem size with ``MSCCLPP_EP_BENCH_TOKENS`` / ``_HIDDEN``.
 """
@@ -261,55 +262,84 @@ def main():
     is_token_in_rank_b = token_idx_in_rank_b >= 0
     x_b = torch.ones((bench_tokens, bench_hidden), dtype=torch.bfloat16, device="cuda") * float(rank)
 
-    def _one_iter():
-        (rx, rxs, rti, rtw, _lst,
-         rpm, gpm, rrcpm, rrps, rgpm, rgps,
-         rsm, sh_rdma, sh_nvl, _ev) = buf.runtime.internode_dispatch(
+    def _dispatch():
+        return buf.runtime.internode_dispatch(
             x_b, None, topk_idx_b, topk_weights_b,
             num_tokens_per_rank_b, num_tokens_per_rdma_rank_b,
             is_token_in_rank_b, num_tokens_per_expert_b,
             0, 0, None, None, None, None, 1, cfg, None, False, False,
         )
-        torch.cuda.synchronize()
-        dist.barrier(group=group)
+
+    def _combine(dout):
+        (rx, _rxs, _rti, rtw, _lst,
+         _rpm, _gpm, rrcpm, rrps, rgpm, _rgps,
+         rsm, sh_rdma, sh_nvl, _ev) = dout
         buf.runtime.internode_combine(
             rx, rtw, rsm, is_token_in_rank_b,
             rrcpm, rrps, rgpm, sh_rdma, sh_nvl,
             cfg, None, False, False,
         )
-        return rx.size(0)
 
+    # Warmup (full round-trip with the sync/barrier guard between phases,
+    # matching the correctness-path invariant).
     for _ in range(warmup):
-        _one_iter()
+        dout = _dispatch()
+        torch.cuda.synchronize()
+        dist.barrier(group=group)
+        _combine(dout)
     torch.cuda.synchronize()
     dist.barrier(group=group)
 
+    # Time dispatch alone.
     start_ev = torch.cuda.Event(enable_timing=True)
     end_ev = torch.cuda.Event(enable_timing=True)
     start_ev.record()
-    recv_tokens_total = 0
+    dout = None
     for _ in range(iters):
-        recv_tokens_total += _one_iter()
+        dout = _dispatch()
     end_ev.record()
     torch.cuda.synchronize()
-    elapsed_ms = start_ev.elapsed_time(end_ev)
-    us_per_iter = elapsed_ms * 1e3 / iters
+    disp_us = start_ev.elapsed_time(end_ev) * 1e3 / iters
+    recv_tokens = dout[0].size(0)
 
-    avg_recv = recv_tokens_total / iters
-    bytes_per_iter = 2 * avg_recv * bench_hidden * x_b.element_size()
-    bw_gbps = bytes_per_iter / (us_per_iter * 1e-6) / 1e9
+    # Required guard before combine sees the dispatch outputs (see correctness
+    # path's XXX note). Not included in either phase's timing.
+    torch.cuda.synchronize()
+    dist.barrier(group=group)
 
-    bw_t = torch.tensor([bw_gbps], dtype=torch.float64, device="cuda")
-    us_t = torch.tensor([us_per_iter], dtype=torch.float64, device="cuda")
-    dist.all_reduce(bw_t, op=dist.ReduceOp.SUM, group=group)
-    dist.all_reduce(us_t, op=dist.ReduceOp.MAX, group=group)
+    # Time combine alone (reusing the same dispatch output each iter).
+    start_ev.record()
+    for _ in range(iters):
+        _combine(dout)
+    end_ev.record()
+    torch.cuda.synchronize()
+    comb_us = start_ev.elapsed_time(end_ev) * 1e3 / iters
+
+    bytes_one_way = recv_tokens * bench_hidden * x_b.element_size()
+    disp_bw = bytes_one_way / (disp_us * 1e-6) / 1e9
+    comb_bw = bytes_one_way / (comb_us * 1e-6) / 1e9
+
+    disp_us_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
+    comb_us_t = torch.tensor([comb_us], dtype=torch.float64, device="cuda")
+    disp_bw_t = torch.tensor([disp_bw], dtype=torch.float64, device="cuda")
+    comb_bw_t = torch.tensor([comb_bw], dtype=torch.float64, device="cuda")
+    dist.all_reduce(disp_us_t, op=dist.ReduceOp.MAX, group=group)
+    dist.all_reduce(comb_us_t, op=dist.ReduceOp.MAX, group=group)
+    dist.all_reduce(disp_bw_t, op=dist.ReduceOp.SUM, group=group)
+    dist.all_reduce(comb_bw_t, op=dist.ReduceOp.SUM, group=group)
     if rank == 0:
         print(
             f"[bench internode HT] nodes={num_nodes} num_ranks={num_ranks} "
             f"tokens={bench_tokens} hidden={bench_hidden} "
-            f"warmup={warmup} iters={iters} "
-            f"per-iter={us_t.item():.1f}us (max across ranks) "
-            f"agg_bw={bw_t.item():.2f} GB/s (sum dispatch+combine)",
+            f"warmup={warmup} iters={iters}",
+            flush=True,
+        )
+        print(
+            f" dispatch: {disp_us_t.item():.1f}us (max) agg_bw={disp_bw_t.item():.2f} GB/s",
+            flush=True,
+        )
+        print(
+            f" combine : {comb_us_t.item():.1f}us (max) agg_bw={comb_bw_t.item():.2f} GB/s",
             flush=True,
         )
 
diff --git a/test/python/ext/ep/test_intranode_multirank.py b/test/python/ext/ep/test_intranode_multirank.py
index a4d5c6dc..c9ff4002 100644
--- a/test/python/ext/ep/test_intranode_multirank.py
+++ b/test/python/ext/ep/test_intranode_multirank.py
@@ -7,9 +7,10 @@
 Tests that Buffer::sync() succeeds across N GPUs on a single node and that
 a round-trip dispatch + combine preserves data (sum of top-k weighted copies).
 Set ``MSCCLPP_EP_BENCH=1`` to also run a post-correctness benchmark pass
-(N iterations with CUDA events; reports per-iter latency and effective
-NVLink bandwidth from rank 0). Override iteration counts with
-``MSCCLPP_EP_BENCH_WARMUP`` / ``MSCCLPP_EP_BENCH_ITERS`` / the bench
+that times dispatch and combine **separately** with CUDA events and
+reports per-phase latency (max across ranks) plus aggregate effective
+NVLink bandwidth (sum across ranks). Override iteration counts with
+``MSCCLPP_EP_BENCH_WARMUP`` / ``MSCCLPP_EP_BENCH_ITERS`` and the bench
 problem size with ``MSCCLPP_EP_BENCH_TOKENS`` / ``_HIDDEN``.
 
 This is a minimal adaptation of DeepEP's tests/test_intranode.py stripped
@@ -221,52 +222,71 @@ def main():
     is_token_in_rank_b = token_idx_in_rank_b >= 0
     x_b = torch.ones((bench_tokens, bench_hidden), dtype=torch.bfloat16, device="cuda") * float(rank)
 
-    def _one_iter():
-        (rx, rxs, rti, rtw, _lst, rpm, cpm, rcpm, rsi, sh, _ev) = buf.runtime.intranode_dispatch(
+    def _dispatch():
+        return buf.runtime.intranode_dispatch(
             x_b, None, topk_idx_b, topk_weights_b,
             num_tokens_per_rank_b, is_token_in_rank_b, num_tokens_per_expert_b,
             0, None, None, 1, cfg, None, False, False,
         )
+
+    def _combine(dout):
+        (rx, _rxs, _rti, rtw, _lst, rpm, _cpm, rcpm, rsi, sh, _ev) = dout
         buf.runtime.intranode_combine(
             rx, rtw, rsi, rpm, rcpm, sh,
             cfg, None, False, False,
         )
-        return rx.size(0)
 
-    # Warmup
+    # Warmup (full round-trip).
     for _ in range(warmup):
-        _one_iter()
+        _combine(_dispatch())
     torch.cuda.synchronize()
     dist.barrier(group=group)
 
+    # Time dispatch alone.
     start_ev = torch.cuda.Event(enable_timing=True)
     end_ev = torch.cuda.Event(enable_timing=True)
     start_ev.record()
-    recv_tokens_total = 0
+    dout = None
     for _ in range(iters):
-        recv_tokens_total += _one_iter()
+        dout = _dispatch()
     end_ev.record()
     torch.cuda.synchronize()
-    elapsed_ms = start_ev.elapsed_time(end_ev)
-    us_per_iter = elapsed_ms * 1e3 / iters
+    disp_us = start_ev.elapsed_time(end_ev) * 1e3 / iters
+    recv_tokens = dout[0].size(0)
 
-    # Rough effective BW per rank: dispatched + combined payload bytes through
-    # comm. We treat each iter as sending `recv_tokens * hidden * elt_size`
-    # one way for dispatch and the same back for combine.
-    avg_recv = recv_tokens_total / iters
-    bytes_per_iter = 2 * avg_recv * bench_hidden * x_b.element_size()
-    bw_gbps = bytes_per_iter / (us_per_iter * 1e-6) / 1e9
+    # Time combine alone (reusing the same dispatch output each iter).
+    dist.barrier(group=group)
+    start_ev.record()
+    for _ in range(iters):
+        _combine(dout)
+    end_ev.record()
+    torch.cuda.synchronize()
+    comb_us = start_ev.elapsed_time(end_ev) * 1e3 / iters
 
-    # Aggregate across ranks
-    bw_t = torch.tensor([bw_gbps], dtype=torch.float64, device="cuda")
-    us_t = torch.tensor([us_per_iter], dtype=torch.float64, device="cuda")
-    dist.all_reduce(bw_t, op=dist.ReduceOp.SUM, group=group)
-    dist.all_reduce(us_t, op=dist.ReduceOp.MAX, group=group)
+    # One-way payload bytes (per phase) per rank.
+    bytes_one_way = recv_tokens * bench_hidden * x_b.element_size()
+    disp_bw = bytes_one_way / (disp_us * 1e-6) / 1e9
+    comb_bw = bytes_one_way / (comb_us * 1e-6) / 1e9
+
+    disp_us_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
+    comb_us_t = torch.tensor([comb_us], dtype=torch.float64, device="cuda")
+    disp_bw_t = torch.tensor([disp_bw], dtype=torch.float64, device="cuda")
+    comb_bw_t = torch.tensor([comb_bw], dtype=torch.float64, device="cuda")
+    dist.all_reduce(disp_us_t, op=dist.ReduceOp.MAX, group=group)
+    dist.all_reduce(comb_us_t, op=dist.ReduceOp.MAX, group=group)
+    dist.all_reduce(disp_bw_t, op=dist.ReduceOp.SUM, group=group)
+    dist.all_reduce(comb_bw_t, op=dist.ReduceOp.SUM, group=group)
     if rank == 0:
         print(
-            f"[bench intranode] tokens={bench_tokens} hidden={bench_hidden} "
-            f"warmup={warmup} iters={iters} "
-            f"per-iter={us_t.item():.1f}us (max across ranks) "
-            f"agg_bw={bw_t.item():.2f} GB/s (sum dispatch+combine)",
+            f"[bench intranode HT] tokens={bench_tokens} hidden={bench_hidden} "
+            f"warmup={warmup} iters={iters}",
+            flush=True,
+        )
+        print(
+            f" dispatch: {disp_us_t.item():.1f}us (max) agg_bw={disp_bw_t.item():.2f} GB/s",
+            flush=True,
+        )
+        print(
+            f" combine : {comb_us_t.item():.1f}us (max) agg_bw={comb_bw_t.item():.2f} GB/s",
             flush=True,
         )
diff --git a/test/python/ext/ep/test_low_latency_multirank.py b/test/python/ext/ep/test_low_latency_multirank.py
index 0b3fd660..1e9f6ba5 100644
--- a/test/python/ext/ep/test_low_latency_multirank.py
+++ b/test/python/ext/ep/test_low_latency_multirank.py
@@ -203,6 +203,89 @@ def main():
     if rank == 0:
         print("PASS", flush=True)
+
+    # ------------------------------------------------------------------
+    # Optional benchmark (enable with MSCCLPP_EP_BENCH=1). Times dispatch
+    # and combine separately, reporting per-iter latency (max across ranks)
+    # and aggregate effective bandwidth (sum across ranks).
+    # ------------------------------------------------------------------
+    if os.environ.get("MSCCLPP_EP_BENCH", "0") != "1":
+        return
+
+    warmup = int(os.environ.get("MSCCLPP_EP_BENCH_WARMUP", "5"))
+    iters = int(os.environ.get("MSCCLPP_EP_BENCH_ITERS", "20"))
+
+    def _dispatch():
+        return buf.low_latency_dispatch(
+            x, topk_idx, num_tokens, num_experts,
+            False, False, False,  # use_fp8, async, return_recv_hook
+        )
+
+    def _combine(dout):
+        (recv_x, _scales, _cnt, src_info_, layout_range_, _ev, _hk) = dout
+        out_ = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
+        buf.low_latency_combine(
+            recv_x.clone(), topk_idx, topk_weights,
+            src_info_, layout_range_,
+            num_tokens, num_experts,
+            False, False, False,
+            out_,
+        )
+
+    for _ in range(warmup):
+        _combine(_dispatch())
+    torch.cuda.synchronize()
+    dist.barrier(group=group)
+
+    start_ev = torch.cuda.Event(enable_timing=True)
+    end_ev = torch.cuda.Event(enable_timing=True)
+    start_ev.record()
+    dout = None
+    for _ in range(iters):
+        dout = _dispatch()
+    end_ev.record()
+    torch.cuda.synchronize()
+    disp_us = start_ev.elapsed_time(end_ev) * 1e3 / iters
+    recv_tokens = int(dout[2].sum().item())  # packed_recv_count summed over local experts
+
+    dist.barrier(group=group)
+    start_ev.record()
+    for _ in range(iters):
+        _combine(dout)
+    end_ev.record()
+    torch.cuda.synchronize()
+    comb_us = start_ev.elapsed_time(end_ev) * 1e3 / iters
+
+    # Dispatch payload: recv_tokens × hidden × bf16 (received on this rank).
+    # Combine payload: num_tokens × hidden × bf16 (sent from each local expert
+    # back to the owning rank; one token's worth of bytes per reduction).
+    disp_bytes = recv_tokens * hidden * 2
+    comb_bytes = num_tokens * hidden * 2
+    disp_bw = disp_bytes / (disp_us * 1e-6) / 1e9
+    comb_bw = comb_bytes / (comb_us * 1e-6) / 1e9
+
+    disp_us_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
+    comb_us_t = torch.tensor([comb_us], dtype=torch.float64, device="cuda")
+    disp_bw_t = torch.tensor([disp_bw], dtype=torch.float64, device="cuda")
+    comb_bw_t = torch.tensor([comb_bw], dtype=torch.float64, device="cuda")
+    dist.all_reduce(disp_us_t, op=dist.ReduceOp.MAX, group=group)
+    dist.all_reduce(comb_us_t, op=dist.ReduceOp.MAX, group=group)
+    dist.all_reduce(disp_bw_t, op=dist.ReduceOp.SUM, group=group)
+    dist.all_reduce(comb_bw_t, op=dist.ReduceOp.SUM, group=group)
+    if rank == 0:
+        print(
+            f"[bench LL] num_ranks={num_ranks} tokens={num_tokens} hidden={hidden} "
+            f"num_experts={num_experts} warmup={warmup} iters={iters}",
+            flush=True,
+        )
+        print(
+            f" dispatch: {disp_us_t.item():.1f}us (max) agg_bw={disp_bw_t.item():.2f} GB/s",
+            flush=True,
+        )
+        print(
+            f" combine : {comb_us_t.item():.1f}us (max) agg_bw={comb_bw_t.item():.2f} GB/s",
            flush=True,
+        )
 
 
 if __name__ == "__main__":
     try:
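
For reference, all three benchmarks share one reporting convention. The sketch below is not part of the patch; it uses made-up per-rank numbers, and ``elem_size=2`` assumes the bf16 payloads the tests send, to show how per-rank phase timings become the printed lines:

    # Illustrative only: per-phase accounting used by the benchmarks above.
    # recv_tokens/hidden and the timings are invented; elem_size=2 assumes bf16.
    recv_tokens, hidden, elem_size = 4096, 7168, 2        # per rank, per phase
    disp_us = [210.0, 195.0]                              # per-rank avg dispatch latency (us)
    comb_us = [230.0, 240.0]                              # per-rank avg combine latency (us)

    bytes_one_way = recv_tokens * hidden * elem_size      # one-way payload per rank
    disp_bw = [bytes_one_way / (t * 1e-6) / 1e9 for t in disp_us]  # GB/s per rank
    comb_bw = [bytes_one_way / (t * 1e-6) / 1e9 for t in comb_us]

    # Latency is reduced with MAX across ranks, bandwidth with SUM (aggregate).
    print(f" dispatch: {max(disp_us):.1f}us (max) agg_bw={sum(disp_bw):.2f} GB/s")
    print(f" combine : {max(comb_us):.1f}us (max) agg_bw={sum(comb_bw):.2f} GB/s")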