mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 09:17:06 +00:00
tests/ep: size HT buffers for bench hidden so bench phase fits
This commit is contained in:
@@ -108,10 +108,13 @@ def main():

 x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * float(rank)

-# Buffer config for internode HT: needs num_rdma_bytes > 0.
+# Buffer config for internode HT: needs num_rdma_bytes > 0. Size buffers
+# using max(hidden, bench_hidden) so the optional bench phase fits.
 cfg = ep.Config(20, 8, 256, 16, 128)
-num_nvl_bytes = cfg.get_nvl_buffer_size_hint(hidden * x.element_size(), num_ranks)
-num_rdma_bytes = cfg.get_rdma_buffer_size_hint(hidden * x.element_size(), num_ranks)
+_bench_on = os.environ.get("MSCCLPP_EP_BENCH", "0") == "1"
+_buf_hidden = max(hidden, int(os.environ.get("MSCCLPP_EP_BENCH_HIDDEN", "0"))) if _bench_on else hidden
+num_nvl_bytes = cfg.get_nvl_buffer_size_hint(_buf_hidden * x.element_size(), num_ranks)
+num_rdma_bytes = cfg.get_rdma_buffer_size_hint(_buf_hidden * x.element_size(), num_ranks)
 if rank == 0:
 print(f"[cfg] num_nodes={num_nodes} num_ranks={num_ranks} num_tokens={num_tokens} "
 f"hidden={hidden} num_experts={num_experts} num_topk={num_topk} "
@@ -96,9 +96,12 @@ def main():
 # Token payload = rank id (cast to bf16) so we can check correctness
 x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * float(rank)

-# Allocate Buffer (intranode only: num_rdma_bytes=0)
+# Allocate Buffer (intranode only: num_rdma_bytes=0). Size the NVL buffer
+# using max(hidden, bench_hidden) so the optional bench phase fits.
 cfg = ep.Config(20, 8, 256)
-num_nvl_bytes = cfg.get_nvl_buffer_size_hint(hidden * x.element_size(), num_ranks)
+_bench_on = os.environ.get("MSCCLPP_EP_BENCH", "0") == "1"
+_buf_hidden = max(hidden, int(os.environ.get("MSCCLPP_EP_BENCH_HIDDEN", "0"))) if _bench_on else hidden
+num_nvl_bytes = cfg.get_nvl_buffer_size_hint(_buf_hidden * x.element_size(), num_ranks)
 if rank == 0:
 print(f"[cfg] num_ranks={num_ranks} num_tokens={num_tokens} hidden={hidden} "
 f"num_experts={num_experts} num_topk={num_topk} num_nvl_bytes={num_nvl_bytes}",
Reference in New Issue
Block a user