tests/ep: size HT buffers for bench hidden so bench phase fits

This commit is contained in:
Qinghua Zhou
2026-04-23 17:13:09 +00:00
parent 441bfa5265
commit 906fa3c48f
2 changed files with 11 additions and 5 deletions

View File

@@ -108,10 +108,13 @@ def main():
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * float(rank)
# Buffer config for internode HT: needs num_rdma_bytes > 0.
# Buffer config for internode HT: needs num_rdma_bytes > 0. Size buffers
# using max(hidden, bench_hidden) so the optional bench phase fits.
cfg = ep.Config(20, 8, 256, 16, 128)
num_nvl_bytes = cfg.get_nvl_buffer_size_hint(hidden * x.element_size(), num_ranks)
num_rdma_bytes = cfg.get_rdma_buffer_size_hint(hidden * x.element_size(), num_ranks)
_bench_on = os.environ.get("MSCCLPP_EP_BENCH", "0") == "1"
_buf_hidden = max(hidden, int(os.environ.get("MSCCLPP_EP_BENCH_HIDDEN", "0"))) if _bench_on else hidden
num_nvl_bytes = cfg.get_nvl_buffer_size_hint(_buf_hidden * x.element_size(), num_ranks)
num_rdma_bytes = cfg.get_rdma_buffer_size_hint(_buf_hidden * x.element_size(), num_ranks)
if rank == 0:
print(f"[cfg] num_nodes={num_nodes} num_ranks={num_ranks} num_tokens={num_tokens} "
f"hidden={hidden} num_experts={num_experts} num_topk={num_topk} "

View File

@@ -96,9 +96,12 @@ def main():
# Token payload = rank id (cast to bf16) so we can check correctness
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * float(rank)
# Allocate Buffer (intranode only: num_rdma_bytes=0)
# Allocate Buffer (intranode only: num_rdma_bytes=0). Size the NVL buffer
# using max(hidden, bench_hidden) so the optional bench phase fits.
cfg = ep.Config(20, 8, 256)
num_nvl_bytes = cfg.get_nvl_buffer_size_hint(hidden * x.element_size(), num_ranks)
_bench_on = os.environ.get("MSCCLPP_EP_BENCH", "0") == "1"
_buf_hidden = max(hidden, int(os.environ.get("MSCCLPP_EP_BENCH_HIDDEN", "0"))) if _bench_on else hidden
num_nvl_bytes = cfg.get_nvl_buffer_size_hint(_buf_hidden * x.element_size(), num_ranks)
if rank == 0:
print(f"[cfg] num_ranks={num_ranks} num_tokens={num_tokens} hidden={hidden} "
f"num_experts={num_experts} num_topk={num_topk} num_nvl_bytes={num_nvl_bytes}",