diff --git a/test/python/ext/ep/test_low_latency_multirank.py b/test/python/ext/ep/test_low_latency_multirank.py index 1e9f6ba5..d82ea309 100644 --- a/test/python/ext/ep/test_low_latency_multirank.py +++ b/test/python/ext/ep/test_low_latency_multirank.py @@ -256,10 +256,12 @@ def main(): comb_us = start_ev.elapsed_time(end_ev) * 1e3 / iters # Dispatch payload: recv_tokens × hidden × bf16 (received on this rank). - # Combine payload: num_tokens × hidden × bf16 (sent from each local expert - # back to the owning rank; one token's worth of bytes per reduction). + # Combine payload: recv_tokens × hidden × bf16 as well -- each local expert + # sends one copy per dispatched token back to its owner, so the bytes on + # the wire match dispatch. Using num_tokens × hidden here would under-count + # the actual send payload by ~num_topk×. disp_bytes = recv_tokens * hidden * 2 - comb_bytes = num_tokens * hidden * 2 + comb_bytes = recv_tokens * hidden * 2 disp_bw = disp_bytes / (disp_us * 1e-6) / 1e9 comb_bw = comb_bytes / (comb_us * 1e-6) / 1e9