mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 09:17:06 +00:00
Get rid of comm.setup()
This commit is contained in:
committed by
Saeed Maleki
parent
0863e862f5
commit
8cb63a7d1a
@@ -40,12 +40,11 @@ def setup_connections(comm, rank, world_size, element_size, proxy_service):
|
||||
for r in range(world_size):
|
||||
if r == rank:
|
||||
continue
|
||||
conn = comm.connect_on_setup(r, 0, mscclpp.Transport.CudaIpc)
|
||||
conn = comm.connect(r, 0, mscclpp.Transport.CudaIpc)
|
||||
connections.append(conn)
|
||||
comm.send_memory_on_setup(reg_mem, r, 0)
|
||||
remote_mem = comm.recv_memory_on_setup(r, 0)
|
||||
comm.send_memory(reg_mem, r, 0)
|
||||
remote_mem = comm.recv_memory(r, 0)
|
||||
remote_memories.append(remote_mem)
|
||||
comm.setup()
|
||||
|
||||
connections = [conn.get() for conn in connections]
|
||||
|
||||
|
||||
@@ -35,15 +35,13 @@ def main(args):
|
||||
size = elements * memory.itemsize
|
||||
my_reg_mem = comm.register_memory(ptr, size, mscclpp.Transport.IB0)
|
||||
|
||||
conn = comm.connect_on_setup((rank + 1) % 2, 0, mscclpp.Transport.IB0)
|
||||
conn = comm.connect((rank + 1) % 2, 0, mscclpp.Transport.IB0)
|
||||
|
||||
other_reg_mem = None
|
||||
if rank == 0:
|
||||
other_reg_mem = comm.recv_memory_on_setup((rank + 1) % 2, 0)
|
||||
other_reg_mem = comm.recv_memory((rank + 1) % 2, 0)
|
||||
else:
|
||||
comm.send_memory_on_setup(my_reg_mem, (rank + 1) % 2, 0)
|
||||
|
||||
comm.setup()
|
||||
comm.send_memory(my_reg_mem, (rank + 1) % 2, 0)
|
||||
|
||||
if rank == 0:
|
||||
other_reg_mem = other_reg_mem.get()
|
||||
|
||||
@@ -21,19 +21,17 @@ extern void register_utils(nb::module_& m);
|
||||
extern void register_numa(nb::module_& m);
|
||||
|
||||
template <typename T>
|
||||
void def_nonblocking_future(nb::handle& m, const std::string& typestr) {
|
||||
std::string pyclass_name = std::string("NonblockingFuture") + typestr;
|
||||
nb::class_<NonblockingFuture<T>>(m, pyclass_name.c_str())
|
||||
.def("ready", &NonblockingFuture<T>::ready)
|
||||
.def("get", &NonblockingFuture<T>::get);
|
||||
void def_future(nb::handle& m, const std::string& typestr) {
|
||||
std::string pyclass_name = std::string("std_future_") + typestr;
|
||||
nb::class_<std::future<T>>(m, pyclass_name.c_str()).def("get", &std::future<T>::get);
|
||||
}
|
||||
|
||||
void register_core(nb::module_& m) {
|
||||
m.def("version", &version);
|
||||
|
||||
nb::class_<Bootstrap>(m, "Bootstrap")
|
||||
.def("get_rank", &Bootstrap::getRank)
|
||||
.def("get_n_ranks", &Bootstrap::getNranks)
|
||||
.def_prop_ro("rank", &Bootstrap::rank)
|
||||
.def_prop_ro("size", &Bootstrap::size)
|
||||
.def(
|
||||
"send",
|
||||
[](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) {
|
||||
@@ -45,15 +43,15 @@ void register_core(nb::module_& m) {
|
||||
"recv",
|
||||
[](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) {
|
||||
void* data = reinterpret_cast<void*>(ptr);
|
||||
self->recv(data, size, peer, tag);
|
||||
return self->recv(data, size, peer, tag);
|
||||
},
|
||||
nb::arg("data"), nb::arg("size"), nb::arg("peer"), nb::arg("tag"))
|
||||
.def("all_gather", &Bootstrap::allGather, nb::arg("allData"), nb::arg("size"))
|
||||
.def("barrier", &Bootstrap::barrier)
|
||||
.def("send", (void (Bootstrap::*)(const std::vector<char>&, int, int)) & Bootstrap::send, nb::arg("data"),
|
||||
nb::arg("peer"), nb::arg("tag"))
|
||||
.def("recv", (void (Bootstrap::*)(std::vector<char>&, int, int)) & Bootstrap::recv, nb::arg("data"),
|
||||
nb::arg("peer"), nb::arg("tag"));
|
||||
.def("recv", (std::future<std::vector<char>>(Bootstrap::*)(int, int)) & Bootstrap::recv, nb::arg("peer"),
|
||||
nb::arg("tag"));
|
||||
|
||||
nb::class_<UniqueId>(m, "UniqueId");
|
||||
|
||||
@@ -149,8 +147,8 @@ void register_core(nb::module_& m) {
|
||||
.def("create_endpoint", &Context::createEndpoint, nb::arg("config"))
|
||||
.def("connect", &Context::connect, nb::arg("local_endpoint"), nb::arg("remote_endpoint"));
|
||||
|
||||
def_nonblocking_future<RegisteredMemory>(m, "RegisteredMemory");
|
||||
def_nonblocking_future<std::shared_ptr<Connection>>(m, "shared_ptr_Connection");
|
||||
def_future<RegisteredMemory>(m, "RegisteredMemory");
|
||||
def_future<std::shared_ptr<Connection>>(m, "shared_ptr_Connection");
|
||||
|
||||
nb::class_<Communicator>(m, "Communicator")
|
||||
.def(nb::init<std::shared_ptr<Bootstrap>, std::shared_ptr<Context>>(), nb::arg("bootstrap"),
|
||||
@@ -163,14 +161,11 @@ void register_core(nb::module_& m) {
|
||||
return self->registerMemory((void*)ptr, size, transports);
|
||||
},
|
||||
nb::arg("ptr"), nb::arg("size"), nb::arg("transports"))
|
||||
.def("send_memory_on_setup", &Communicator::sendMemoryOnSetup, nb::arg("memory"), nb::arg("remoteRank"),
|
||||
nb::arg("tag"))
|
||||
.def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag"))
|
||||
.def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"),
|
||||
nb::arg("localConfig"))
|
||||
.def("send_memory", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remoteRank"), nb::arg("tag"))
|
||||
.def("recv_memory", &Communicator::recvMemory, nb::arg("remoteRank"), nb::arg("tag"))
|
||||
.def("connect", &Communicator::connect, nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
|
||||
.def("remote_rank_of", &Communicator::remoteRankOf)
|
||||
.def("tag_of", &Communicator::tagOf)
|
||||
.def("setup", &Communicator::setup);
|
||||
.def("tag_of", &Communicator::tagOf);
|
||||
}
|
||||
|
||||
NB_MODULE(_mscclpp, m) {
|
||||
|
||||
@@ -76,8 +76,7 @@ class MscclppGroup:
|
||||
def make_connection(self, remote_ranks: list[int], transport: Transport) -> dict[int, Connection]:
|
||||
connections = {}
|
||||
for rank in remote_ranks:
|
||||
connections[rank] = self.communicator.connect_on_setup(rank, 0, transport)
|
||||
self.communicator.setup()
|
||||
connections[rank] = self.communicator.connect(rank, 0, transport)
|
||||
connections = {rank: connections[rank].get() for rank in connections}
|
||||
return connections
|
||||
|
||||
@@ -93,9 +92,8 @@ class MscclppGroup:
|
||||
all_registered_memories[self.my_rank] = local_reg_memory
|
||||
future_memories = {}
|
||||
for rank in connections:
|
||||
self.communicator.send_memory_on_setup(local_reg_memory, rank, 0)
|
||||
future_memories[rank] = self.communicator.recv_memory_on_setup(rank, 0)
|
||||
self.communicator.setup()
|
||||
self.communicator.send_memory(local_reg_memory, rank, 0)
|
||||
future_memories[rank] = self.communicator.recv_memory(rank, 0)
|
||||
for rank in connections:
|
||||
all_registered_memories[rank] = future_memories[rank].get()
|
||||
return all_registered_memories
|
||||
@@ -108,7 +106,6 @@ class MscclppGroup:
|
||||
semaphores = {}
|
||||
for rank in connections:
|
||||
semaphores[rank] = semaphore_type(self.communicator, connections[rank])
|
||||
self.communicator.setup()
|
||||
return semaphores
|
||||
|
||||
def make_sm_channels(self, tensor: cp.ndarray, connections: dict[int, Connection]) -> dict[int, SmChannel]:
|
||||
|
||||
Reference in New Issue
Block a user