mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-25 15:24:43 +00:00
Configure IPC domain per communicator
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -282,6 +282,8 @@ void register_core(nb::module_& m) {
|
||||
nb::arg("context") = nullptr)
|
||||
.def("bootstrap", &Communicator::bootstrap)
|
||||
.def("context", &Communicator::context)
|
||||
.def("set_ipc_domain_n_ranks", &Communicator::setIpcDomainNranks, nb::arg("n_ranks"))
|
||||
.def("get_ipc_domain_n_ranks", &Communicator::getIpcDomainNranks)
|
||||
.def(
|
||||
"register_memory",
|
||||
[](Communicator* self, uintptr_t ptr, size_t size, TransportFlags transports) {
|
||||
|
||||
@@ -35,6 +35,7 @@ class CommGroup:
|
||||
interfaceIpPortTrio: str = "",
|
||||
rank: int = None,
|
||||
size: int = None,
|
||||
ipc_domain_n_ranks: int = 0,
|
||||
):
|
||||
if interfaceIpPortTrio == "" and (mpi_comm is not None or torch_group is not None):
|
||||
uniq_id = None
|
||||
@@ -70,9 +71,11 @@ class CommGroup:
|
||||
else:
|
||||
raise RuntimeError("Either the interface or mpi_group need to be specified")
|
||||
self.communicator = CppCommunicator(self.bootstrap)
|
||||
self.communicator.set_ipc_domain_n_ranks(ipc_domain_n_ranks)
|
||||
self.my_rank = self.bootstrap.get_rank()
|
||||
self.nranks = self.bootstrap.get_n_ranks()
|
||||
self.nranks_per_node = self.bootstrap.get_n_ranks_per_node()
|
||||
self.ipc_domain_n_ranks = self.communicator.get_ipc_domain_n_ranks()
|
||||
|
||||
def barrier(self):
|
||||
self.bootstrap.barrier()
|
||||
|
||||
Reference in New Issue
Block a user