From d97fef4395aca1e61678baa33a7d9dda9b514cd0 Mon Sep 17 00:00:00 2001 From: aashaka Date: Thu, 8 Feb 2024 09:55:29 -0800 Subject: [PATCH] Allow semaphores and memory to be registered separately in ProxyService (#264) This is needed in use cases where SimpleProxyChannel does not suffice. For example, when a single semaphore is to be used for multiple tensors or when multiple semaphores should be associated with a tensor. --- python/mscclpp/comm.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/python/mscclpp/comm.py b/python/mscclpp/comm.py index 9b7d5e16..b4f11ae4 100644 --- a/python/mscclpp/comm.py +++ b/python/mscclpp/comm.py @@ -220,3 +220,24 @@ class CommGroup: proxy_service.proxy_channel(semaphore_ids[rank]), memory_ids[rank], memory_ids[self.my_rank] ) return channels + + def register_semaphore_with_proxy( + self, proxy_service: ProxyService, connections: dict[int, Connection] + ) -> dict[int, SmChannel]: + semaphores = self.make_semaphore(connections, Host2DeviceSemaphore) + semaphore_ids = {} + for rank in semaphores: + semaphore_ids[rank] = proxy_service.add_semaphore(semaphores[rank]) + channels = {} + for rank in semaphores: + channels[rank] = proxy_service.proxy_channel(semaphore_ids[rank]) + return channels + + def register_memory_with_proxy( + self, proxy_service: ProxyService, tensor: cp.ndarray, connections: dict[int, Connection] + ) -> dict[int, int]: + registered_memories = self.register_tensor_with_connections(tensor, connections) + memory_ids = {} + for rank in registered_memories: + memory_ids[rank] = proxy_service.add_memory(registered_memories[rank]) + return memory_ids