mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-20 06:49:29 +00:00
Update the benchmark to improve the rank mapping, communicator creation, and backend selection
This commit is contained in:
@@ -25,17 +25,25 @@ def main():
|
||||
rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", os.environ.get("PMI_RANK", 0)))
|
||||
world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", os.environ.get("PMI_SIZE", 1)))
|
||||
|
||||
# Set CUDA device
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", rank % torch.cuda.device_count()))
|
||||
# Set CUDA device — prefer MPI-provided local rank to handle any rank mapping
|
||||
local_rank = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK",
|
||||
os.environ.get("MPI_LOCALRANKID",
|
||||
os.environ.get("LOCAL_RANK", rank % torch.cuda.device_count()))))
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
# Initialize torch.distributed with NCCL (need MASTER_ADDR/PORT)
|
||||
# Initialize torch.distributed — use NCCL when torch_fn benchmarks are needed,
|
||||
# otherwise gloo avoids IB configuration issues on some clusters.
|
||||
# Set ALLTOALLV_BACKEND=nccl to enable torch baseline comparison.
|
||||
backend = os.environ.get("ALLTOALLV_BACKEND", "gloo")
|
||||
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
|
||||
os.environ.setdefault("MASTER_PORT", "29500")
|
||||
os.environ["RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size,
|
||||
device_id=torch.device(f"cuda:{local_rank}"))
|
||||
if backend == "nccl":
|
||||
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size,
|
||||
device_id=torch.device(f"cuda:{local_rank}"))
|
||||
else:
|
||||
dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
|
||||
|
||||
if rank == 0:
|
||||
print(f"Testing MscclppAlltoAllV with {world_size} ranks")
|
||||
@@ -48,31 +56,20 @@ def main():
|
||||
UniqueId,
|
||||
)
|
||||
from mscclpp.ext.alltoallv_single import MscclppAlltoAllV
|
||||
import pickle
|
||||
from mpi4py import MPI
|
||||
mpi_comm = MPI.COMM_WORLD
|
||||
|
||||
# Create mscclpp communicator with TcpBootstrap
|
||||
# Use torch.distributed to share the unique ID via pickle
|
||||
# Broadcast UniqueId raw bytes (128 bytes) via MPI to avoid NCCL interception issues
|
||||
bootstrap = TcpBootstrap(rank, world_size)
|
||||
|
||||
if rank == 0:
|
||||
unique_id = bootstrap.create_unique_id()
|
||||
# Serialize UniqueId via pickle and broadcast
|
||||
pickled = pickle.dumps(unique_id)
|
||||
id_tensor = torch.zeros(256, dtype=torch.uint8, device='cuda')
|
||||
id_tensor[:len(pickled)] = torch.tensor(list(pickled), dtype=torch.uint8)
|
||||
# Also send length
|
||||
len_tensor = torch.tensor([len(pickled)], dtype=torch.int64, device='cuda')
|
||||
else:
|
||||
id_tensor = torch.zeros(256, dtype=torch.uint8, device='cuda')
|
||||
len_tensor = torch.zeros(1, dtype=torch.int64, device='cuda')
|
||||
unique_id = UniqueId()
|
||||
|
||||
dist.broadcast(len_tensor, src=0)
|
||||
dist.broadcast(id_tensor, src=0)
|
||||
|
||||
if rank != 0:
|
||||
pickled_len = int(len_tensor.item())
|
||||
pickled = bytes(id_tensor[:pickled_len].cpu().tolist())
|
||||
unique_id = pickle.loads(pickled)
|
||||
# UniqueId supports pickle (__getstate__/__setstate__), MPI bcast uses pickle
|
||||
unique_id = mpi_comm.bcast(unique_id, root=0)
|
||||
|
||||
bootstrap.initialize(unique_id)
|
||||
comm = Communicator(bootstrap)
|
||||
|
||||
Reference in New Issue
Block a user