Configure IPC domain per communicator

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Binyang Li
2026-05-15 22:26:53 +00:00
parent ee82cc4c41
commit dbebde2b58
10 changed files with 41 additions and 27 deletions

View File

@@ -282,6 +282,8 @@ void register_core(nb::module_& m) {
nb::arg("context") = nullptr)
.def("bootstrap", &Communicator::bootstrap)
.def("context", &Communicator::context)
.def("set_ipc_domain_n_ranks", &Communicator::setIpcDomainNranks, nb::arg("n_ranks"))
.def("get_ipc_domain_n_ranks", &Communicator::getIpcDomainNranks)
.def(
"register_memory",
[](Communicator* self, uintptr_t ptr, size_t size, TransportFlags transports) {

View File

@@ -35,6 +35,7 @@ class CommGroup:
interfaceIpPortTrio: str = "",
rank: int = None,
size: int = None,
ipc_domain_n_ranks: int = 0,
):
if interfaceIpPortTrio == "" and (mpi_comm is not None or torch_group is not None):
uniq_id = None
@@ -70,9 +71,11 @@ class CommGroup:
else:
raise RuntimeError("Either the interface or mpi_group need to be specified")
self.communicator = CppCommunicator(self.bootstrap)
self.communicator.set_ipc_domain_n_ranks(ipc_domain_n_ranks)
self.my_rank = self.bootstrap.get_rank()
self.nranks = self.bootstrap.get_n_ranks()
self.nranks_per_node = self.bootstrap.get_n_ranks_per_node()
self.ipc_domain_n_ranks = self.communicator.get_ipc_domain_n_ranks()
def barrier(self):
self.bootstrap.barrier()