Get rid of comm.setup()

This commit is contained in:
Olli Saarikivi
2023-08-31 17:45:58 +00:00
committed by Saeed Maleki
parent 0863e862f5
commit 8cb63a7d1a
23 changed files with 253 additions and 352 deletions

View File

@@ -40,12 +40,11 @@ def setup_connections(comm, rank, world_size, element_size, proxy_service):
for r in range(world_size):
if r == rank:
continue
conn = comm.connect_on_setup(r, 0, mscclpp.Transport.CudaIpc)
conn = comm.connect(r, 0, mscclpp.Transport.CudaIpc)
connections.append(conn)
comm.send_memory_on_setup(reg_mem, r, 0)
remote_mem = comm.recv_memory_on_setup(r, 0)
comm.send_memory(reg_mem, r, 0)
remote_mem = comm.recv_memory(r, 0)
remote_memories.append(remote_mem)
comm.setup()
connections = [conn.get() for conn in connections]

View File

@@ -35,15 +35,13 @@ def main(args):
size = elements * memory.itemsize
my_reg_mem = comm.register_memory(ptr, size, mscclpp.Transport.IB0)
conn = comm.connect_on_setup((rank + 1) % 2, 0, mscclpp.Transport.IB0)
conn = comm.connect((rank + 1) % 2, 0, mscclpp.Transport.IB0)
other_reg_mem = None
if rank == 0:
other_reg_mem = comm.recv_memory_on_setup((rank + 1) % 2, 0)
other_reg_mem = comm.recv_memory((rank + 1) % 2, 0)
else:
comm.send_memory_on_setup(my_reg_mem, (rank + 1) % 2, 0)
comm.setup()
comm.send_memory(my_reg_mem, (rank + 1) % 2, 0)
if rank == 0:
other_reg_mem = other_reg_mem.get()

View File

@@ -21,19 +21,17 @@ extern void register_utils(nb::module_& m);
extern void register_numa(nb::module_& m);
template <typename T>
void def_nonblocking_future(nb::handle& m, const std::string& typestr) {
std::string pyclass_name = std::string("NonblockingFuture") + typestr;
nb::class_<NonblockingFuture<T>>(m, pyclass_name.c_str())
.def("ready", &NonblockingFuture<T>::ready)
.def("get", &NonblockingFuture<T>::get);
void def_future(nb::handle& m, const std::string& typestr) {
std::string pyclass_name = std::string("std_future_") + typestr;
nb::class_<std::future<T>>(m, pyclass_name.c_str()).def("get", &std::future<T>::get);
}
void register_core(nb::module_& m) {
m.def("version", &version);
nb::class_<Bootstrap>(m, "Bootstrap")
.def("get_rank", &Bootstrap::getRank)
.def("get_n_ranks", &Bootstrap::getNranks)
.def_prop_ro("rank", &Bootstrap::rank)
.def_prop_ro("size", &Bootstrap::size)
.def(
"send",
[](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) {
@@ -45,15 +43,15 @@ void register_core(nb::module_& m) {
"recv",
[](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) {
void* data = reinterpret_cast<void*>(ptr);
self->recv(data, size, peer, tag);
return self->recv(data, size, peer, tag);
},
nb::arg("data"), nb::arg("size"), nb::arg("peer"), nb::arg("tag"))
.def("all_gather", &Bootstrap::allGather, nb::arg("allData"), nb::arg("size"))
.def("barrier", &Bootstrap::barrier)
.def("send", (void (Bootstrap::*)(const std::vector<char>&, int, int)) & Bootstrap::send, nb::arg("data"),
nb::arg("peer"), nb::arg("tag"))
.def("recv", (void (Bootstrap::*)(std::vector<char>&, int, int)) & Bootstrap::recv, nb::arg("data"),
nb::arg("peer"), nb::arg("tag"));
.def("recv", (std::future<std::vector<char>>(Bootstrap::*)(int, int)) & Bootstrap::recv, nb::arg("peer"),
nb::arg("tag"));
nb::class_<UniqueId>(m, "UniqueId");
@@ -149,8 +147,8 @@ void register_core(nb::module_& m) {
.def("create_endpoint", &Context::createEndpoint, nb::arg("config"))
.def("connect", &Context::connect, nb::arg("local_endpoint"), nb::arg("remote_endpoint"));
def_nonblocking_future<RegisteredMemory>(m, "RegisteredMemory");
def_nonblocking_future<std::shared_ptr<Connection>>(m, "shared_ptr_Connection");
def_future<RegisteredMemory>(m, "RegisteredMemory");
def_future<std::shared_ptr<Connection>>(m, "shared_ptr_Connection");
nb::class_<Communicator>(m, "Communicator")
.def(nb::init<std::shared_ptr<Bootstrap>, std::shared_ptr<Context>>(), nb::arg("bootstrap"),
@@ -163,14 +161,11 @@ void register_core(nb::module_& m) {
return self->registerMemory((void*)ptr, size, transports);
},
nb::arg("ptr"), nb::arg("size"), nb::arg("transports"))
.def("send_memory_on_setup", &Communicator::sendMemoryOnSetup, nb::arg("memory"), nb::arg("remoteRank"),
nb::arg("tag"))
.def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag"))
.def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"),
nb::arg("localConfig"))
.def("send_memory", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remoteRank"), nb::arg("tag"))
.def("recv_memory", &Communicator::recvMemory, nb::arg("remoteRank"), nb::arg("tag"))
.def("connect", &Communicator::connect, nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
.def("remote_rank_of", &Communicator::remoteRankOf)
.def("tag_of", &Communicator::tagOf)
.def("setup", &Communicator::setup);
.def("tag_of", &Communicator::tagOf);
}
NB_MODULE(_mscclpp, m) {

View File

@@ -76,8 +76,7 @@ class MscclppGroup:
def make_connection(self, remote_ranks: list[int], transport: Transport) -> dict[int, Connection]:
connections = {}
for rank in remote_ranks:
connections[rank] = self.communicator.connect_on_setup(rank, 0, transport)
self.communicator.setup()
connections[rank] = self.communicator.connect(rank, 0, transport)
connections = {rank: connections[rank].get() for rank in connections}
return connections
@@ -93,9 +92,8 @@ class MscclppGroup:
all_registered_memories[self.my_rank] = local_reg_memory
future_memories = {}
for rank in connections:
self.communicator.send_memory_on_setup(local_reg_memory, rank, 0)
future_memories[rank] = self.communicator.recv_memory_on_setup(rank, 0)
self.communicator.setup()
self.communicator.send_memory(local_reg_memory, rank, 0)
future_memories[rank] = self.communicator.recv_memory(rank, 0)
for rank in connections:
all_registered_memories[rank] = future_memories[rank].get()
return all_registered_memories
@@ -108,7 +106,6 @@ class MscclppGroup:
semaphores = {}
for rank in connections:
semaphores[rank] = semaphore_type(self.communicator, connections[rank])
self.communicator.setup()
return semaphores
def make_sm_channels(self, tensor: cp.ndarray, connections: dict[int, Connection]) -> dict[int, SmChannel]: