Allow semaphores and memory to be registered separately in ProxyService (#264)

This is needed in use cases where SimpleProxyChannel does not suffice.
For example, when a single semaphore is to be used for multiple tensors
or when multiple semaphores should be associated with a tensor.
This commit is contained in:
aashaka
2024-02-08 09:55:29 -08:00
committed by GitHub
parent 7c229fbdd8
commit d97fef4395

View File

@@ -220,3 +220,24 @@ class CommGroup:
proxy_service.proxy_channel(semaphore_ids[rank]), memory_ids[rank], memory_ids[self.my_rank]
)
return channels
def register_semaphore_with_proxy(
self, proxy_service: ProxyService, connections: dict[int, Connection]
) -> dict[int, SmChannel]:
semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
semaphore_ids = {}
for rank in semaphores:
semaphore_ids[rank] = proxy_service.add_semaphore(semaphores[rank])
channels = {}
for rank in semaphores:
channels[rank] = proxy_service.proxy_channel(semaphore_ids[rank])
return channels
def register_memory_with_proxy(
self, proxy_service: ProxyService, tensor: cp.ndarray, connections: dict[int, Connection]
) -> dict[int, int]:
registered_memories = self.register_tensor_with_connections(tensor, connections)
memory_ids = {}
for rank in registered_memories:
memory_ids[rank] = proxy_service.add_memory(registered_memories[rank])
return memory_ids