Get rid of comm.setup()

This commit is contained in:
Olli Saarikivi
2023-08-31 17:45:58 +00:00
committed by Saeed Maleki
parent 0863e862f5
commit 8cb63a7d1a
23 changed files with 253 additions and 352 deletions

View File

@@ -76,8 +76,7 @@ class MscclppGroup:
def make_connection(self, remote_ranks: list[int], transport: Transport) -> dict[int, Connection]:
connections = {}
for rank in remote_ranks:
connections[rank] = self.communicator.connect_on_setup(rank, 0, transport)
self.communicator.setup()
connections[rank] = self.communicator.connect(rank, 0, transport)
connections = {rank: connections[rank].get() for rank in connections}
return connections
@@ -93,9 +92,8 @@ class MscclppGroup:
all_registered_memories[self.my_rank] = local_reg_memory
future_memories = {}
for rank in connections:
self.communicator.send_memory_on_setup(local_reg_memory, rank, 0)
future_memories[rank] = self.communicator.recv_memory_on_setup(rank, 0)
self.communicator.setup()
self.communicator.send_memory(local_reg_memory, rank, 0)
future_memories[rank] = self.communicator.recv_memory(rank, 0)
for rank in connections:
all_registered_memories[rank] = future_memories[rank].get()
return all_registered_memories
@@ -108,7 +106,6 @@ class MscclppGroup:
semaphores = {}
for rank in connections:
semaphores[rank] = semaphore_type(self.communicator, connections[rank])
self.communicator.setup()
return semaphores
def make_sm_channels(self, tensor: cp.ndarray, connections: dict[int, Connection]) -> dict[int, SmChannel]: