mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 09:17:06 +00:00
tests/ep: hoist combine output tensor out of the timed loop
The LL combine benchmark was cloning the ~58 MB dispatch recv buffer
('recv_x.clone()') on every timed iteration, adding ~20 us of D2D
memcpy per sample and masking kernel-level changes. It also called
torch.empty() for the output inside the loop. Both now live outside
the timed region; the kernel is invoked against a persistent bench_out
and the recv_x produced by the most recent dispatch.
This commit is contained in:
@@ -221,11 +221,16 @@ def main():
            False, False, False,  # use_fp8, async, return_recv_hook
        )

-    def _combine(dout):
+    # Hoist combine's output-tensor allocation out of the timed loop so the
+    # measurement reflects the kernel cost. (The original test also cloned the
+    # ~58 MB dispatch recv buffer on every iter, adding ~20 us of D2D memcpy
+    # to each combine sample and masking kernel-level changes.)
+    bench_out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
+
+    def _combine(dout, out_):
        (recv_x, _scales, _cnt, src_info_, layout_range_, _ev, _hk) = dout
-        out_ = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
        buf.low_latency_combine(
-            recv_x.clone(), topk_idx, topk_weights,
+            recv_x, topk_idx, topk_weights,
            src_info_, layout_range_,
            num_tokens, num_experts,
            False, False, False,
@@ -233,7 +238,7 @@ def main():
        )

    for _ in range(warmup):
-        _combine(_dispatch())
+        _combine(_dispatch(), bench_out)
    torch.cuda.synchronize()
    dist.barrier(group=group)

@@ -251,7 +256,7 @@ def main():
    dist.barrier(group=group)
    start_ev.record()
    for _ in range(iters):
-        _combine(dout)
+        _combine(dout, bench_out)
    end_ev.record()
    torch.cuda.synchronize()
    comb_us = start_ev.elapsed_time(end_ev) * 1e3 / iters

||||
Reference in New Issue
Block a user