Use smart pointer for IB structure (#585)

Change to use smart pointer for IB structure. Registered memory will own
ibMr; ibCtx will not hold the reference
- Use smart pointer for IbQp and IbMr
- Update memoryChannel API, keep localRegisteredMemory
- Close fd when registeredMemory is released

---------

Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
This commit is contained in:
Binyang Li
2025-08-06 10:01:58 -07:00
committed by GitHub
parent d55ac96f5e
commit 4f6f23dae3
23 changed files with 175 additions and 118 deletions

View File

@@ -2,7 +2,7 @@
# Licensed under the MIT license.
from __future__ import annotations
from typing import Type
from typing import Tuple, Type
import cupy as cp
from ._mscclpp import (
@@ -109,18 +109,7 @@ class CommGroup:
def register_tensor_with_connections(
self, tensor: Type[cp.ndarray] | Type[np.ndarray], connections: dict[int, Connection]
) -> dict[int, RegisteredMemory]:
transport_flags = TransportFlags()
for rank in connections:
transport_flags |= connections[rank].transport()
data_ptr = (
tensor.data.ptr
if isinstance(tensor, cp.ndarray)
else tensor.data_ptr() if is_torch_tensor(tensor) else tensor.ctypes.data
)
tensor_size = (
tensor.numel() * tensor.element_size() if is_torch_tensor(tensor) else tensor.size * tensor.itemsize
)
local_reg_memory = self.communicator.register_memory(data_ptr, tensor_size, transport_flags)
local_reg_memory = self.register_local_memory(tensor, connections)
all_registered_memories = {}
all_registered_memories[self.my_rank] = local_reg_memory
future_memories = {}
@@ -131,6 +120,19 @@ class CommGroup:
all_registered_memories[rank] = future_memories[rank].get()
return all_registered_memories
def _register_memory_with_connections(
self, memory: RegisteredMemory, connections: dict[int, Connection]
) -> dict[int, RegisteredMemory]:
all_registered_memories = {}
all_registered_memories[self.my_rank] = memory
future_memories = {}
for rank in connections:
self.communicator.send_memory(memory, rank)
future_memories[rank] = self.communicator.recv_memory(rank)
for rank in connections:
all_registered_memories[rank] = future_memories[rank].get()
return all_registered_memories
def make_semaphore(
self,
connections: dict[int, Connection],
@@ -145,31 +147,36 @@ class CommGroup:
semaphores = self.make_semaphore(connections, MemoryDevice2DeviceSemaphore)
registered_memories = self.register_tensor_with_connections(tensor, connections)
channels = {}
tensor_data_ptr = tensor.data_ptr() if is_torch_tensor(tensor) else tensor.data.ptr
for rank in connections:
channels[rank] = MemoryChannel(semaphores[rank], registered_memories[rank], tensor_data_ptr)
channels[rank] = MemoryChannel(
semaphores[rank], registered_memories[rank], registered_memories[self.my_rank]
)
return channels
def make_memory_channels_with_scratch(
self,
tensor: cp.ndarray,
scratchTensor: cp.ndarray,
registeredScratchBuffer: RegisteredMemory,
connections: dict[int, Connection],
) -> dict[int, MemoryChannel]:
semaphores = self.make_semaphore(connections, MemoryDevice2DeviceSemaphore)
registered_memories = self.register_tensor_with_connections(scratchTensor, connections)
registered_memories = self._register_memory_with_connections(registeredScratchBuffer, connections)
channels = {}
tensor_data_ptr = tensor.data_ptr() if is_torch_tensor(tensor) else tensor.data.ptr
scratch_data_ptr = scratchTensor.data_ptr() if is_torch_tensor(scratchTensor) else scratchTensor.data.ptr
tensor_size = (
tensor.numel() * tensor.element_size() if is_torch_tensor(tensor) else tensor.size * tensor.itemsize
)
local_registered_memory = self.communicator.register_memory(tensor_data_ptr, tensor_size, TransportFlags())
scratch_data_ptr = registeredScratchBuffer.data()
for rank in connections:
channels[rank] = MemoryChannel(
semaphores[rank], registered_memories[rank], tensor_data_ptr, scratch_data_ptr
semaphores[rank], registered_memories[rank], local_registered_memory, scratch_data_ptr
)
return channels
def make_port_channels(
self, proxy_service: ProxyService, tensor: cp.ndarray, connections: dict[int, Connection]
) -> dict[int, MemoryChannel]:
) -> dict[int, PortChannel]:
semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
registered_memories = self.register_tensor_with_connections(tensor, connections)
memory_ids = {}
@@ -187,9 +194,9 @@ class CommGroup:
self,
proxy_service: ProxyService,
tensor: cp.ndarray,
scratchTensor: cp.ndarray,
registeredScratchBuffer: RegisteredMemory,
connections: dict[int, Connection],
) -> dict[int, MemoryChannel]:
) -> dict[int, PortChannel]:
transport_flags = TransportFlags()
for rank in connections:
transport_flags |= connections[rank].transport()
@@ -204,7 +211,7 @@ class CommGroup:
local_reg_memory = self.communicator.register_memory(data_ptr, tensor_size, transport_flags)
semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
registered_memories = self.register_tensor_with_connections(scratchTensor, connections)
registered_memories = self._register_memory_with_connections(registeredScratchBuffer, connections)
memory_ids = {}
semaphore_ids = {}
for rank in registered_memories:
@@ -221,7 +228,7 @@ class CommGroup:
def register_semaphore_with_proxy(
self, proxy_service: ProxyService, connections: dict[int, Connection]
) -> dict[int, MemoryChannel]:
) -> dict[int, PortChannel]:
semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
semaphore_ids = {}
for rank in semaphores:
@@ -239,3 +246,17 @@ class CommGroup:
for rank in registered_memories:
memory_ids[rank] = proxy_service.add_memory(registered_memories[rank])
return memory_ids
def register_local_memory(self, tensor: cp.ndarray, connections: dict[int, Connection]) -> RegisteredMemory:
transport_flags = TransportFlags()
for rank in connections:
transport_flags |= connections[rank].transport()
data_ptr = (
tensor.data.ptr
if isinstance(tensor, cp.ndarray)
else tensor.data_ptr() if is_torch_tensor(tensor) else tensor.ctypes.data
)
tensor_size = (
tensor.numel() * tensor.element_size() if is_torch_tensor(tensor) else tensor.size * tensor.itemsize
)
return self.communicator.register_memory(data_ptr, tensor_size, transport_flags)

View File

@@ -126,7 +126,7 @@ void register_core(nb::module_& m) {
nb::class_<RegisteredMemory>(m, "RegisteredMemory")
.def(nb::init<>())
.def("data", &RegisteredMemory::data)
.def("data", [](RegisteredMemory& self) { return reinterpret_cast<uintptr_t>(self.data()); })
.def("size", &RegisteredMemory::size)
.def("transports", &RegisteredMemory::transports)
.def("serialize", &RegisteredMemory::serialize)

View File

@@ -28,13 +28,11 @@ void register_memory_channel(nb::module_& m) {
.def(nb::init<>())
.def("__init__",
[](MemoryChannel* memoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
RegisteredMemory dst,
uintptr_t src) { new (memoryChannel) MemoryChannel(semaphore, dst, reinterpret_cast<void*>(src)); })
RegisteredMemory dst, RegisteredMemory src) { new (memoryChannel) MemoryChannel(semaphore, dst, src); })
.def("__init__",
[](MemoryChannel* memoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
RegisteredMemory dst, uintptr_t src, uintptr_t packet_buffer) {
new (memoryChannel)
MemoryChannel(semaphore, dst, reinterpret_cast<void*>(src), reinterpret_cast<void*>(packet_buffer));
RegisteredMemory dst, RegisteredMemory src, uintptr_t packet_buffer) {
new (memoryChannel) MemoryChannel(semaphore, dst, src, reinterpret_cast<void*>(packet_buffer));
})
.def("device_handle", &MemoryChannel::deviceHandle);