mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 17:26:04 +00:00
Removing MPI Dependency (#743)
This commit is contained in:
@@ -19,8 +19,8 @@ from mscclpp._mscclpp import (
|
||||
CppTransport,
|
||||
CppTransportFlags,
|
||||
)
|
||||
import mpi4py
|
||||
import numpy as np
|
||||
import pickle
|
||||
|
||||
from mscclpp.utils import is_torch_tensor
|
||||
|
||||
@@ -29,20 +29,35 @@ __all__ = ["CommGroup"]
|
||||
|
||||
class CommGroup:
|
||||
def __init__(
|
||||
self, mpi_comm: mpi4py.MPI.Comm = None, interfaceIpPortTrio: str = "", rank: int = None, size: int = None
|
||||
self,
|
||||
mpi_comm: "mpi4py.MPI.Comm" = None,
|
||||
torch_group: "dist.ProcessGroup" = None,
|
||||
interfaceIpPortTrio: str = "",
|
||||
rank: int = None,
|
||||
size: int = None,
|
||||
):
|
||||
if interfaceIpPortTrio == "":
|
||||
self.bootstrap = CppTcpBootstrap.create(mpi_comm.rank, mpi_comm.size)
|
||||
if interfaceIpPortTrio == "" and (mpi_comm is not None or torch_group is not None):
|
||||
uniq_id = None
|
||||
if mpi_comm.rank == 0:
|
||||
# similar to NCCL's unique id
|
||||
self.bootstrap = CppTcpBootstrap.create(rank, size)
|
||||
if rank == 0:
|
||||
uniq_id = self.bootstrap.create_unique_id()
|
||||
uniq_id_global = mpi_comm.bcast(uniq_id, 0)
|
||||
if mpi_comm is not None:
|
||||
import mpi4py
|
||||
|
||||
uniq_id_global = mpi_comm.bcast(uniq_id, 0)
|
||||
else:
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
if rank == 0:
|
||||
uniq_id_global = uniq_id
|
||||
pickled_data = pickle.dumps(uniq_id)
|
||||
data_tensor = torch.frombuffer(bytearray(pickled_data), dtype=torch.uint8).clone()
|
||||
else:
|
||||
data_tensor = torch.zeros(256, dtype=torch.uint8)
|
||||
dist.broadcast(data_tensor, src=0, group=torch_group)
|
||||
uniq_id_global = pickle.loads(data_tensor.numpy().tobytes())
|
||||
self.bootstrap.initialize(uniq_id_global)
|
||||
elif mpi_comm:
|
||||
# use this instead
|
||||
self.bootstrap = CppTcpBootstrap.create(mpi_comm.rank, mpi_comm.size)
|
||||
self.bootstrap.initialize(interfaceIpPortTrio)
|
||||
elif not interfaceIpPortTrio == "":
|
||||
assert rank >= 0 and size >= 1
|
||||
self.bootstrap = CppTcpBootstrap.create(rank, size)
|
||||
|
||||
Reference in New Issue
Block a user