Fix Python bindings and tests (#690)

Minimal fix to make things work. We need a more careful look at
preventing silent fallback of nanobind when it fails to (properly)
construct a C++ STL object with mscclpp instances.
This commit is contained in:
Changho Hwang
2025-11-21 12:53:12 -08:00
committed by GitHub
parent 060c35fec6
commit 8b8593ba51
8 changed files with 73 additions and 63 deletions

View File

@@ -47,6 +47,7 @@ from ._mscclpp import (
connect_nvls_collective,
EndpointConfig,
Fifo,
Semaphore,
Host2DeviceSemaphore,
Host2HostSemaphore,
numa,
@@ -79,6 +80,7 @@ __all__ = [
"connect_nvls_collective",
"EndpointConfig",
"Fifo",
"Semaphore",
"Host2DeviceSemaphore",
"Host2HostSemaphore",
"numa",

View File

@@ -10,6 +10,7 @@ from ._mscclpp import (
Connection,
connect_nvls_collective,
EndpointConfig,
Semaphore,
Host2DeviceSemaphore,
Host2HostSemaphore,
ProxyService,
@@ -133,18 +134,14 @@ class CommGroup:
all_registered_memories[rank] = future_memories[rank].get()
return all_registered_memories
def make_semaphore(
self,
connections: dict[int, Connection],
semaphore_type: Type[Host2HostSemaphore] | Type[Host2DeviceSemaphore] | Type[MemoryDevice2DeviceSemaphore],
) -> dict[int, Host2HostSemaphore]:
semaphores = {}
def make_semaphores(self, connections: dict[int, Connection]) -> dict[int, Semaphore]:
future_semaphores = {}
for rank in connections:
semaphores[rank] = semaphore_type(self.communicator, connections[rank])
return semaphores
future_semaphores[rank] = self.communicator.build_semaphore(connections[rank], rank)
return {rank: future.get() for rank, future in future_semaphores.items()}
def make_memory_channels(self, tensor: cp.ndarray, connections: dict[int, Connection]) -> dict[int, MemoryChannel]:
semaphores = self.make_semaphore(connections, MemoryDevice2DeviceSemaphore)
semaphores = self.make_semaphores(connections)
registered_memories = self.register_tensor_with_connections(tensor, connections)
channels = {}
for rank in connections:
@@ -159,7 +156,7 @@ class CommGroup:
registeredScratchBuffer: RegisteredMemory,
connections: dict[int, Connection],
) -> dict[int, MemoryChannel]:
semaphores = self.make_semaphore(connections, MemoryDevice2DeviceSemaphore)
semaphores = self.make_semaphores(connections)
registered_memories = self._register_memory_with_connections(registeredScratchBuffer, connections)
channels = {}
tensor_data_ptr = tensor.data_ptr() if is_torch_tensor(tensor) else tensor.data.ptr
@@ -177,7 +174,7 @@ class CommGroup:
def make_port_channels(
self, proxy_service: ProxyService, tensor: cp.ndarray, connections: dict[int, Connection]
) -> dict[int, PortChannel]:
semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
semaphores = self.make_semaphores(connections)
registered_memories = self.register_tensor_with_connections(tensor, connections)
memory_ids = {}
semaphore_ids = {}
@@ -210,7 +207,7 @@ class CommGroup:
)
local_reg_memory = self.communicator.register_memory(data_ptr, tensor_size, transport_flags)
semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
semaphores = self.make_semaphores(connections)
registered_memories = self._register_memory_with_connections(registeredScratchBuffer, connections)
memory_ids = {}
semaphore_ids = {}
@@ -229,7 +226,7 @@ class CommGroup:
def register_semaphore_with_proxy(
self, proxy_service: ProxyService, connections: dict[int, Connection]
) -> dict[int, PortChannel]:
semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
semaphores = self.make_semaphores(connections)
semaphore_ids = {}
for rank in semaphores:
semaphore_ids[rank] = proxy_service.add_semaphore(semaphores[rank])