mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-25 23:34:49 +00:00
detect ipc domain automaticlly
This commit is contained in:
@@ -56,6 +56,7 @@ void register_core(nb::module_& m) {
|
||||
.def("get_rank", &Bootstrap::getRank)
|
||||
.def("get_n_ranks", &Bootstrap::getNranks)
|
||||
.def("get_n_ranks_per_node", &Bootstrap::getNranksPerNode)
|
||||
.def("get_n_ranks_per_ipc_domain", &Bootstrap::getNranksPerIpcDomain)
|
||||
.def(
|
||||
"send",
|
||||
[](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) {
|
||||
@@ -282,8 +283,6 @@ void register_core(nb::module_& m) {
|
||||
nb::arg("context") = nullptr)
|
||||
.def("bootstrap", &Communicator::bootstrap)
|
||||
.def("context", &Communicator::context)
|
||||
.def("set_ipc_domain_n_ranks", &Communicator::setIpcDomainNranks, nb::arg("n_ranks"))
|
||||
.def("get_ipc_domain_n_ranks", &Communicator::getIpcDomainNranks)
|
||||
.def(
|
||||
"register_memory",
|
||||
[](Communicator* self, uintptr_t ptr, size_t size, TransportFlags transports) {
|
||||
|
||||
@@ -35,7 +35,6 @@ class CommGroup:
|
||||
interfaceIpPortTrio: str = "",
|
||||
rank: int = None,
|
||||
size: int = None,
|
||||
ipc_domain_n_ranks: int = 0,
|
||||
):
|
||||
if interfaceIpPortTrio == "" and (mpi_comm is not None or torch_group is not None):
|
||||
uniq_id = None
|
||||
@@ -71,11 +70,10 @@ class CommGroup:
|
||||
else:
|
||||
raise RuntimeError("Either the interface or mpi_group need to be specified")
|
||||
self.communicator = CppCommunicator(self.bootstrap)
|
||||
self.communicator.set_ipc_domain_n_ranks(ipc_domain_n_ranks)
|
||||
self.my_rank = self.bootstrap.get_rank()
|
||||
self.nranks = self.bootstrap.get_n_ranks()
|
||||
self.nranks_per_node = self.bootstrap.get_n_ranks_per_node()
|
||||
self.ipc_domain_n_ranks = self.communicator.get_ipc_domain_n_ranks()
|
||||
self.ipc_domain_n_ranks = self.bootstrap.get_n_ranks_per_ipc_domain()
|
||||
|
||||
def barrier(self):
|
||||
self.bootstrap.barrier()
|
||||
|
||||
Reference in New Issue
Block a user