Removing MPI Dependency (#743)

This commit is contained in:
Caio Rocha
2026-02-20 16:04:12 -08:00
committed by GitHub
parent 39865c218b
commit e2acf7f1c8
8 changed files with 34 additions and 19 deletions

View File

@@ -19,8 +19,8 @@ from mscclpp._mscclpp import (
CppTransport,
CppTransportFlags,
)
import mpi4py
import numpy as np
import pickle
from mscclpp.utils import is_torch_tensor
@@ -29,20 +29,35 @@ __all__ = ["CommGroup"]
class CommGroup:
def __init__(
self, mpi_comm: mpi4py.MPI.Comm = None, interfaceIpPortTrio: str = "", rank: int = None, size: int = None
self,
mpi_comm: "mpi4py.MPI.Comm" = None,
torch_group: "dist.ProcessGroup" = None,
interfaceIpPortTrio: str = "",
rank: int = None,
size: int = None,
):
if interfaceIpPortTrio == "":
self.bootstrap = CppTcpBootstrap.create(mpi_comm.rank, mpi_comm.size)
if interfaceIpPortTrio == "" and (mpi_comm is not None or torch_group is not None):
uniq_id = None
if mpi_comm.rank == 0:
# similar to NCCL's unique id
self.bootstrap = CppTcpBootstrap.create(rank, size)
if rank == 0:
uniq_id = self.bootstrap.create_unique_id()
uniq_id_global = mpi_comm.bcast(uniq_id, 0)
if mpi_comm is not None:
import mpi4py
uniq_id_global = mpi_comm.bcast(uniq_id, 0)
else:
import torch
import torch.distributed as dist
if rank == 0:
uniq_id_global = uniq_id
pickled_data = pickle.dumps(uniq_id)
data_tensor = torch.frombuffer(bytearray(pickled_data), dtype=torch.uint8).clone()
else:
data_tensor = torch.zeros(256, dtype=torch.uint8)
dist.broadcast(data_tensor, src=0, group=torch_group)
uniq_id_global = pickle.loads(data_tensor.numpy().tobytes())
self.bootstrap.initialize(uniq_id_global)
elif mpi_comm:
# use this instead
self.bootstrap = CppTcpBootstrap.create(mpi_comm.rank, mpi_comm.size)
self.bootstrap.initialize(interfaceIpPortTrio)
elif not interfaceIpPortTrio == "":
assert rank >= 0 and size >= 1
self.bootstrap = CppTcpBootstrap.create(rank, size)