mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-20 21:09:13 +00:00
Fix Python bindings and tests (#690)
This is a minimal fix to make things work. We still need to look more carefully at how to prevent nanobind from silently falling back when it fails to properly construct a C++ STL object with mscclpp instances.
This commit is contained in:
@@ -47,6 +47,7 @@ from ._mscclpp import (
|
||||
connect_nvls_collective,
|
||||
EndpointConfig,
|
||||
Fifo,
|
||||
Semaphore,
|
||||
Host2DeviceSemaphore,
|
||||
Host2HostSemaphore,
|
||||
numa,
|
||||
@@ -79,6 +80,7 @@ __all__ = [
|
||||
"connect_nvls_collective",
|
||||
"EndpointConfig",
|
||||
"Fifo",
|
||||
"Semaphore",
|
||||
"Host2DeviceSemaphore",
|
||||
"Host2HostSemaphore",
|
||||
"numa",
|
||||
|
||||
@@ -10,6 +10,7 @@ from ._mscclpp import (
|
||||
Connection,
|
||||
connect_nvls_collective,
|
||||
EndpointConfig,
|
||||
Semaphore,
|
||||
Host2DeviceSemaphore,
|
||||
Host2HostSemaphore,
|
||||
ProxyService,
|
||||
@@ -133,18 +134,14 @@ class CommGroup:
|
||||
all_registered_memories[rank] = future_memories[rank].get()
|
||||
return all_registered_memories
|
||||
|
||||
def make_semaphore(
|
||||
self,
|
||||
connections: dict[int, Connection],
|
||||
semaphore_type: Type[Host2HostSemaphore] | Type[Host2DeviceSemaphore] | Type[MemoryDevice2DeviceSemaphore],
|
||||
) -> dict[int, Host2HostSemaphore]:
|
||||
semaphores = {}
|
||||
def make_semaphores(self, connections: dict[int, Connection]) -> dict[int, Semaphore]:
|
||||
future_semaphores = {}
|
||||
for rank in connections:
|
||||
semaphores[rank] = semaphore_type(self.communicator, connections[rank])
|
||||
return semaphores
|
||||
future_semaphores[rank] = self.communicator.build_semaphore(connections[rank], rank)
|
||||
return {rank: future.get() for rank, future in future_semaphores.items()}
|
||||
|
||||
def make_memory_channels(self, tensor: cp.ndarray, connections: dict[int, Connection]) -> dict[int, MemoryChannel]:
|
||||
semaphores = self.make_semaphore(connections, MemoryDevice2DeviceSemaphore)
|
||||
semaphores = self.make_semaphores(connections)
|
||||
registered_memories = self.register_tensor_with_connections(tensor, connections)
|
||||
channels = {}
|
||||
for rank in connections:
|
||||
@@ -159,7 +156,7 @@ class CommGroup:
|
||||
registeredScratchBuffer: RegisteredMemory,
|
||||
connections: dict[int, Connection],
|
||||
) -> dict[int, MemoryChannel]:
|
||||
semaphores = self.make_semaphore(connections, MemoryDevice2DeviceSemaphore)
|
||||
semaphores = self.make_semaphores(connections)
|
||||
registered_memories = self._register_memory_with_connections(registeredScratchBuffer, connections)
|
||||
channels = {}
|
||||
tensor_data_ptr = tensor.data_ptr() if is_torch_tensor(tensor) else tensor.data.ptr
|
||||
@@ -177,7 +174,7 @@ class CommGroup:
|
||||
def make_port_channels(
|
||||
self, proxy_service: ProxyService, tensor: cp.ndarray, connections: dict[int, Connection]
|
||||
) -> dict[int, PortChannel]:
|
||||
semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
|
||||
semaphores = self.make_semaphores(connections)
|
||||
registered_memories = self.register_tensor_with_connections(tensor, connections)
|
||||
memory_ids = {}
|
||||
semaphore_ids = {}
|
||||
@@ -210,7 +207,7 @@ class CommGroup:
|
||||
)
|
||||
local_reg_memory = self.communicator.register_memory(data_ptr, tensor_size, transport_flags)
|
||||
|
||||
semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
|
||||
semaphores = self.make_semaphores(connections)
|
||||
registered_memories = self._register_memory_with_connections(registeredScratchBuffer, connections)
|
||||
memory_ids = {}
|
||||
semaphore_ids = {}
|
||||
@@ -229,7 +226,7 @@ class CommGroup:
|
||||
def register_semaphore_with_proxy(
|
||||
self, proxy_service: ProxyService, connections: dict[int, Connection]
|
||||
) -> dict[int, PortChannel]:
|
||||
semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
|
||||
semaphores = self.make_semaphores(connections)
|
||||
semaphore_ids = {}
|
||||
for rank in semaphores:
|
||||
semaphore_ids[rank] = proxy_service.add_semaphore(semaphores[rank])
|
||||
|
||||
Reference in New Issue
Block a user