detect ipc domain automaticlly

This commit is contained in:
Binyang Li
2026-05-16 00:39:49 +00:00
parent 93b43547cc
commit 0744e806fc
23 changed files with 130 additions and 49 deletions

View File

@@ -56,6 +56,7 @@ void register_core(nb::module_& m) {
.def("get_rank", &Bootstrap::getRank)
.def("get_n_ranks", &Bootstrap::getNranks)
.def("get_n_ranks_per_node", &Bootstrap::getNranksPerNode)
.def("get_n_ranks_per_ipc_domain", &Bootstrap::getNranksPerIpcDomain)
.def(
"send",
[](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) {
@@ -282,8 +283,6 @@ void register_core(nb::module_& m) {
nb::arg("context") = nullptr)
.def("bootstrap", &Communicator::bootstrap)
.def("context", &Communicator::context)
.def("set_ipc_domain_n_ranks", &Communicator::setIpcDomainNranks, nb::arg("n_ranks"))
.def("get_ipc_domain_n_ranks", &Communicator::getIpcDomainNranks)
.def(
"register_memory",
[](Communicator* self, uintptr_t ptr, size_t size, TransportFlags transports) {

View File

@@ -35,7 +35,6 @@ class CommGroup:
interfaceIpPortTrio: str = "",
rank: int = None,
size: int = None,
ipc_domain_n_ranks: int = 0,
):
if interfaceIpPortTrio == "" and (mpi_comm is not None or torch_group is not None):
uniq_id = None
@@ -71,11 +70,10 @@ class CommGroup:
else:
raise RuntimeError("Either the interface or mpi_group need to be specified")
self.communicator = CppCommunicator(self.bootstrap)
self.communicator.set_ipc_domain_n_ranks(ipc_domain_n_ranks)
self.my_rank = self.bootstrap.get_rank()
self.nranks = self.bootstrap.get_n_ranks()
self.nranks_per_node = self.bootstrap.get_n_ranks_per_node()
self.ipc_domain_n_ranks = self.communicator.get_ipc_domain_n_ranks()
self.ipc_domain_n_ranks = self.bootstrap.get_n_ranks_per_ipc_domain()
def barrier(self):
self.bootstrap.barrier()