diff --git a/test/python/ext/ep/test_low_latency_multirank.py b/test/python/ext/ep/test_low_latency_multirank.py
index 0fb0eca9..fd9c8c6d 100644
--- a/test/python/ext/ep/test_low_latency_multirank.py
+++ b/test/python/ext/ep/test_low_latency_multirank.py
@@ -221,11 +221,16 @@ def main():
         False, False, False,  # use_fp8, async, return_recv_hook
     )
 
-    def _combine(dout):
+    # Hoist combine's output-tensor allocation out of the timed loop so the
+    # measurement reflects the kernel cost. (The original test also cloned the
+    # ~58 MB dispatch recv buffer on every iter, adding ~20 us of D2D memcpy
+    # to each combine sample and masking kernel-level changes.)
+    bench_out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
+
+    def _combine(dout, out_):
         (recv_x, _scales, _cnt, src_info_, layout_range_, _ev, _hk) = dout
-        out_ = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
         buf.low_latency_combine(
-            recv_x.clone(), topk_idx, topk_weights,
+            recv_x, topk_idx, topk_weights,
             src_info_, layout_range_, num_tokens, num_experts,
             False, False, False,
 
@@ -233,7 +238,7 @@ def main():
     )
 
     for _ in range(warmup):
-        _combine(_dispatch())
+        _combine(_dispatch(), bench_out)
     torch.cuda.synchronize()
     dist.barrier(group=group)
 
@@ -251,7 +256,7 @@ def main():
     dist.barrier(group=group)
     start_ev.record()
     for _ in range(iters):
-        _combine(dout)
+        _combine(dout, bench_out)
     end_ev.record()
     torch.cuda.synchronize()
     comb_us = start_ev.elapsed_time(end_ev) * 1e3 / iters