diff --git a/test/python/ext/ep/test_internode_multirank.py b/test/python/ext/ep/test_internode_multirank.py index 7c3aafb7..a82b9f2f 100644 --- a/test/python/ext/ep/test_internode_multirank.py +++ b/test/python/ext/ep/test_internode_multirank.py @@ -108,10 +108,13 @@ def main(): x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * float(rank) - # Buffer config for internode HT: needs num_rdma_bytes > 0. + # Buffer config for internode HT: needs num_rdma_bytes > 0. Size buffers + # using max(hidden, bench_hidden) so the optional bench phase fits. cfg = ep.Config(20, 8, 256, 16, 128) - num_nvl_bytes = cfg.get_nvl_buffer_size_hint(hidden * x.element_size(), num_ranks) - num_rdma_bytes = cfg.get_rdma_buffer_size_hint(hidden * x.element_size(), num_ranks) + _bench_on = os.environ.get("MSCCLPP_EP_BENCH", "0") == "1" + _buf_hidden = max(hidden, int(os.environ.get("MSCCLPP_EP_BENCH_HIDDEN", "0"))) if _bench_on else hidden + num_nvl_bytes = cfg.get_nvl_buffer_size_hint(_buf_hidden * x.element_size(), num_ranks) + num_rdma_bytes = cfg.get_rdma_buffer_size_hint(_buf_hidden * x.element_size(), num_ranks) if rank == 0: print(f"[cfg] num_nodes={num_nodes} num_ranks={num_ranks} num_tokens={num_tokens} " f"hidden={hidden} num_experts={num_experts} num_topk={num_topk} " diff --git a/test/python/ext/ep/test_intranode_multirank.py b/test/python/ext/ep/test_intranode_multirank.py index c9ff4002..d2365bae 100644 --- a/test/python/ext/ep/test_intranode_multirank.py +++ b/test/python/ext/ep/test_intranode_multirank.py @@ -96,9 +96,12 @@ def main(): # Token payload = rank id (cast to bf16) so we can check correctness x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * float(rank) - # Allocate Buffer (intranode only: num_rdma_bytes=0) + # Allocate Buffer (intranode only: num_rdma_bytes=0). Size the NVL buffer + # using max(hidden, bench_hidden) so the optional bench phase fits. cfg = ep.Config(20, 8, 256) - num_nvl_bytes = cfg.get_nvl_buffer_size_hint(hidden * x.element_size(), num_ranks) + _bench_on = os.environ.get("MSCCLPP_EP_BENCH", "0") == "1" + _buf_hidden = max(hidden, int(os.environ.get("MSCCLPP_EP_BENCH_HIDDEN", "0"))) if _bench_on else hidden + num_nvl_bytes = cfg.get_nvl_buffer_size_hint(_buf_hidden * x.element_size(), num_ranks) if rank == 0: print(f"[cfg] num_ranks={num_ranks} num_tokens={num_tokens} hidden={hidden} " f"num_experts={num_experts} num_topk={num_topk} num_nvl_bytes={num_nvl_bytes}",