Integrate MSCCL++ with torch. Introduce an NCCL audit shim library (`libmscclpp_audit_nccl.so`); you can use the following commands to launch a torch workload with it. Also avoids breaking the build pipeline on CPU-only machines.

```bash
export LD_AUDIT=$MSCCLPP_INSTALL_DIR/libmscclpp_audit_nccl.so
export LD_LIBRARY_PATH=$MSCCLPP_INSTALL_DIR:$LD_LIBRARY_PATH
torchrun --nnodes=1 --nproc_per_node=8 your_script.py
```
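For context: `LD_AUDIT` points the dynamic linker at the shim through glibc's rtld-audit interface, which lets a library observe and intercept symbol bindings at load time; presumably this is how NCCL calls issued by torch are redirected into MSCCL++ without relinking.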
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Collective correctness verification using PyTorch distributed.

Run examples:
    torchrun --nproc_per_node=4 test/torch/correctness_test.py --collective allreduce --nelem 1048576 --dtype fp16
    torchrun --nproc_per_node=4 test/torch/correctness_test.py --collective allgather --dtype bfloat16
    torchrun --nproc_per_node=4 test/torch/correctness_test.py --collective reduce_scatter --dtype float32
"""

from __future__ import annotations

import argparse
import os
from typing import Tuple

import torch
import torch.distributed as dist
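

# Parameters of a classic 32-bit linear congruential generator (the
# "Numerical Recipes" constants); generate_rank_tensor() uses them to derive
# values that are a deterministic function of (index, rank, seq).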
_A = 1664525
_C = 1013904223
_MASK = 0xFFFFFFFF
_N_DIFFERENT_FLOAT = 4096


def _parse_dtype(name: str) -> torch.dtype:
    name = name.lower()
    if name in {"fp32", "float", "float32", "f32"}:
        return torch.float32
    if name in {"fp16", "half", "float16"}:
        return torch.float16
    if name in {"bf16", "bfloat16"}:
        return torch.bfloat16
    if name in {"int32", "i32"}:
        return torch.int32
    raise ValueError(f"Unsupported dtype: {name}")


def _default_tolerances(dtype: torch.dtype) -> Tuple[float, float]:
    """Return default (rtol, atol) for the given dtype."""
    if dtype in (torch.float16, torch.bfloat16):
        return 5e-3, 1e-3
    return 1e-4, 1e-5


def generate_rank_tensor(
    num_elems: int, rank: int, seq: int = 0, device: torch.device | None = None, *, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """Generate deterministic pseudo-random values in [0, 1) (for integer types, scaled & cast)."""
    if device is None:
        device = torch.device("cuda", rank % torch.cuda.device_count())
    seeds = (torch.arange(num_elems, device=device, dtype=torch.int64) + rank + seq) & _MASK
    seeds = (seeds * _A + _C) & _MASK
    base = seeds.remainder(_N_DIFFERENT_FLOAT).to(torch.float32) / float(_N_DIFFERENT_FLOAT)
    if dtype.is_floating_point:
        return base.to(dtype)
    # Integer path: scale to roughly 0..(2**31 - 1), then cast
    return (base * (2**31 - 1)).to(dtype)
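

# The generator is a pure function of (index, rank, seq), which is what lets
# each rank rebuild its peers' inputs locally in the tests below. A minimal,
# hypothetical self-check (not part of the original test; the name and the
# CPU device are illustrative only):
def _generator_determinism_sketch():
    a = generate_rank_tensor(16, rank=0, seq=0, device=torch.device("cpu"))
    b = generate_rank_tensor(16, rank=0, seq=0, device=torch.device("cpu"))
    assert torch.equal(a, b)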


def _init_dist():
    if dist.is_initialized():
        return
    if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
        raise RuntimeError("Distributed environment variables not set. Run with torchrun.")
    backend = "nccl"
    dist.init_process_group(backend=backend)
    # LOCAL_RANK is set by torchrun; fall back to RANK, which coincides with it
    # in single-node launches.
    local_rank = int(os.environ.get("LOCAL_RANK", os.environ["RANK"]))
    torch.cuda.set_device(local_rank)
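

# Each test regenerates every peer's input locally via generate_rank_tensor,
# so reference results need no extra communication; only the collective under
# test touches the interconnect.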


def run_allreduce_test(num_elems: int, iters: int, dtype: torch.dtype, rtol: float, atol: float):
    _init_dist()
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    for it in range(iters):
        x = generate_rank_tensor(num_elems, rank, seq=it, dtype=dtype)
        # Reference: sum of all ranks' inputs, accumulated in the collective's
        # dtype; the tolerances absorb low-precision rounding differences.
        expected = None
        for r in range(world_size):
            t = generate_rank_tensor(num_elems, r, seq=it, dtype=dtype, device=x.device)
            expected = t if r == 0 else expected.add(t)
        dist.all_reduce(x, op=dist.ReduceOp.SUM)
        _assert_close(x, expected, f"AllReduce (rank {rank}, iter {it})", rtol, atol)


def run_allgather_test(num_elems: int, iters: int, dtype: torch.dtype, rtol: float, atol: float):
    _init_dist()
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    for it in range(iters):
        inp = generate_rank_tensor(num_elems, rank, seq=it, dtype=dtype)
        out = torch.empty(world_size * num_elems, dtype=inp.dtype, device=inp.device)
        # Gather into non-overlapping views of one contiguous buffer so the
        # result can be compared against the concatenated reference in one shot.
        out_views = [out[i * num_elems : (i + 1) * num_elems] for i in range(world_size)]
        expected = torch.cat(
            [generate_rank_tensor(num_elems, r, seq=it, dtype=dtype, device=inp.device) for r in range(world_size)],
            dim=0,
        )
        dist.all_gather(out_views, inp)
        _assert_close(out, expected, f"AllGather (rank {rank}, iter {it})", rtol, atol)


def run_reducescatter_test(num_elems: int, iters: int, dtype: torch.dtype, rtol: float, atol: float):
    _init_dist()
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    for it in range(iters):
        inp = generate_rank_tensor(num_elems * world_size, rank, seq=it, dtype=dtype)
        output = torch.empty(num_elems, dtype=inp.dtype, device=inp.device)
        input_list = list(inp.chunk(world_size))

        # Reference: sum all ranks' full inputs, then take this rank's chunk.
        expected = None
        for r in range(world_size):
            t = generate_rank_tensor(num_elems * world_size, r, seq=it, dtype=dtype, device=output.device)
            expected = t if expected is None else expected.add(t)
        expected = expected.chunk(world_size)[rank]
        dist.reduce_scatter(output, input_list, op=dist.ReduceOp.SUM)
        _assert_close(output, expected, f"ReduceScatter (rank {rank}, iter {it})", rtol, atol)


def _assert_close(result: torch.Tensor, expected: torch.Tensor, context: str, rtol: float, atol: float):
    # Promote to float32 for comparison when needed.
    result_f = result.to(torch.float32) if result.dtype != torch.float32 else result
    expected_f = expected.to(torch.float32) if expected.dtype != torch.float32 else expected
    if not torch.allclose(result_f, expected_f, rtol=rtol, atol=atol):
        max_abs = (result_f - expected_f).abs().max().item()
        rel = max_abs / (expected_f.abs().max().item() + 1e-12)
        raise AssertionError(f"{context} failed: max_abs={max_abs:.3e} rel={rel:.3e} (rtol={rtol} atol={atol})")
    assert torch.isfinite(result_f).all(), f"{context} produced non-finite values"


def main():
    parser = argparse.ArgumentParser(description="MSCCL++ torch collective correctness tester")
    parser.add_argument("--collective", choices=["allreduce", "allgather", "reduce_scatter"], default="allreduce")
    parser.add_argument(
        "--num-elems",
        "--nelem",
        dest="num_elems",
        type=int,
        default=1 << 18,
        help="Elements per rank (or per chunk for reduce_scatter)",
    )
    parser.add_argument("--iters", type=int, default=4, help="Number of iterations to run for each collective")
    parser.add_argument(
        "--dtype",
        type=str,
        default="float32",
        help="Data type: float32|fp16|bfloat16|int32 (only float dtypes fully validated)",
    )
    parser.add_argument("--rtol", type=float, default=None, help="Override relative tolerance")
    parser.add_argument("--atol", type=float, default=None, help="Override absolute tolerance")
    args = parser.parse_args()

    dtype = _parse_dtype(args.dtype)
    rtol, atol = _default_tolerances(dtype)
    if args.rtol is not None:
        rtol = args.rtol
    if args.atol is not None:
        atol = args.atol

    if args.collective == "allreduce":
        run_allreduce_test(args.num_elems, args.iters, dtype, rtol, atol)
    elif args.collective == "allgather":
        run_allgather_test(args.num_elems, args.iters, dtype, rtol, atol)
    elif args.collective == "reduce_scatter":
        run_reducescatter_test(args.num_elems, args.iters, dtype, rtol, atol)
    else:
        raise ValueError("Unknown collective")


if __name__ == "__main__":
    main()