mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-13 01:36:10 +00:00
Use smart pointer for IB structure (#585)
Change to use smart pointer for IB structure. Registered memory will own ibMr; ibCtx will not hold the reference - Use smart pointer for IbQp and IbMr - Update memoryChannel API, keep localRegisteredMemory - Close fd when registeredMemory is released --------- Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Type
|
||||
from typing import Tuple, Type
|
||||
|
||||
import cupy as cp
|
||||
from ._mscclpp import (
|
||||
@@ -109,18 +109,7 @@ class CommGroup:
|
||||
def register_tensor_with_connections(
|
||||
self, tensor: Type[cp.ndarray] | Type[np.ndarray], connections: dict[int, Connection]
|
||||
) -> dict[int, RegisteredMemory]:
|
||||
transport_flags = TransportFlags()
|
||||
for rank in connections:
|
||||
transport_flags |= connections[rank].transport()
|
||||
data_ptr = (
|
||||
tensor.data.ptr
|
||||
if isinstance(tensor, cp.ndarray)
|
||||
else tensor.data_ptr() if is_torch_tensor(tensor) else tensor.ctypes.data
|
||||
)
|
||||
tensor_size = (
|
||||
tensor.numel() * tensor.element_size() if is_torch_tensor(tensor) else tensor.size * tensor.itemsize
|
||||
)
|
||||
local_reg_memory = self.communicator.register_memory(data_ptr, tensor_size, transport_flags)
|
||||
local_reg_memory = self.register_local_memory(tensor, connections)
|
||||
all_registered_memories = {}
|
||||
all_registered_memories[self.my_rank] = local_reg_memory
|
||||
future_memories = {}
|
||||
@@ -131,6 +120,19 @@ class CommGroup:
|
||||
all_registered_memories[rank] = future_memories[rank].get()
|
||||
return all_registered_memories
|
||||
|
||||
def _register_memory_with_connections(
|
||||
self, memory: RegisteredMemory, connections: dict[int, Connection]
|
||||
) -> dict[int, RegisteredMemory]:
|
||||
all_registered_memories = {}
|
||||
all_registered_memories[self.my_rank] = memory
|
||||
future_memories = {}
|
||||
for rank in connections:
|
||||
self.communicator.send_memory(memory, rank)
|
||||
future_memories[rank] = self.communicator.recv_memory(rank)
|
||||
for rank in connections:
|
||||
all_registered_memories[rank] = future_memories[rank].get()
|
||||
return all_registered_memories
|
||||
|
||||
def make_semaphore(
|
||||
self,
|
||||
connections: dict[int, Connection],
|
||||
@@ -145,31 +147,36 @@ class CommGroup:
|
||||
semaphores = self.make_semaphore(connections, MemoryDevice2DeviceSemaphore)
|
||||
registered_memories = self.register_tensor_with_connections(tensor, connections)
|
||||
channels = {}
|
||||
tensor_data_ptr = tensor.data_ptr() if is_torch_tensor(tensor) else tensor.data.ptr
|
||||
for rank in connections:
|
||||
channels[rank] = MemoryChannel(semaphores[rank], registered_memories[rank], tensor_data_ptr)
|
||||
channels[rank] = MemoryChannel(
|
||||
semaphores[rank], registered_memories[rank], registered_memories[self.my_rank]
|
||||
)
|
||||
return channels
|
||||
|
||||
def make_memory_channels_with_scratch(
|
||||
self,
|
||||
tensor: cp.ndarray,
|
||||
scratchTensor: cp.ndarray,
|
||||
registeredScratchBuffer: RegisteredMemory,
|
||||
connections: dict[int, Connection],
|
||||
) -> dict[int, MemoryChannel]:
|
||||
semaphores = self.make_semaphore(connections, MemoryDevice2DeviceSemaphore)
|
||||
registered_memories = self.register_tensor_with_connections(scratchTensor, connections)
|
||||
registered_memories = self._register_memory_with_connections(registeredScratchBuffer, connections)
|
||||
channels = {}
|
||||
tensor_data_ptr = tensor.data_ptr() if is_torch_tensor(tensor) else tensor.data.ptr
|
||||
scratch_data_ptr = scratchTensor.data_ptr() if is_torch_tensor(scratchTensor) else scratchTensor.data.ptr
|
||||
tensor_size = (
|
||||
tensor.numel() * tensor.element_size() if is_torch_tensor(tensor) else tensor.size * tensor.itemsize
|
||||
)
|
||||
local_registered_memory = self.communicator.register_memory(tensor_data_ptr, tensor_size, TransportFlags())
|
||||
scratch_data_ptr = registeredScratchBuffer.data()
|
||||
for rank in connections:
|
||||
channels[rank] = MemoryChannel(
|
||||
semaphores[rank], registered_memories[rank], tensor_data_ptr, scratch_data_ptr
|
||||
semaphores[rank], registered_memories[rank], local_registered_memory, scratch_data_ptr
|
||||
)
|
||||
return channels
|
||||
|
||||
def make_port_channels(
|
||||
self, proxy_service: ProxyService, tensor: cp.ndarray, connections: dict[int, Connection]
|
||||
) -> dict[int, MemoryChannel]:
|
||||
) -> dict[int, PortChannel]:
|
||||
semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
|
||||
registered_memories = self.register_tensor_with_connections(tensor, connections)
|
||||
memory_ids = {}
|
||||
@@ -187,9 +194,9 @@ class CommGroup:
|
||||
self,
|
||||
proxy_service: ProxyService,
|
||||
tensor: cp.ndarray,
|
||||
scratchTensor: cp.ndarray,
|
||||
registeredScratchBuffer: RegisteredMemory,
|
||||
connections: dict[int, Connection],
|
||||
) -> dict[int, MemoryChannel]:
|
||||
) -> dict[int, PortChannel]:
|
||||
transport_flags = TransportFlags()
|
||||
for rank in connections:
|
||||
transport_flags |= connections[rank].transport()
|
||||
@@ -204,7 +211,7 @@ class CommGroup:
|
||||
local_reg_memory = self.communicator.register_memory(data_ptr, tensor_size, transport_flags)
|
||||
|
||||
semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
|
||||
registered_memories = self.register_tensor_with_connections(scratchTensor, connections)
|
||||
registered_memories = self._register_memory_with_connections(registeredScratchBuffer, connections)
|
||||
memory_ids = {}
|
||||
semaphore_ids = {}
|
||||
for rank in registered_memories:
|
||||
@@ -221,7 +228,7 @@ class CommGroup:
|
||||
|
||||
def register_semaphore_with_proxy(
|
||||
self, proxy_service: ProxyService, connections: dict[int, Connection]
|
||||
) -> dict[int, MemoryChannel]:
|
||||
) -> dict[int, PortChannel]:
|
||||
semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
|
||||
semaphore_ids = {}
|
||||
for rank in semaphores:
|
||||
@@ -239,3 +246,17 @@ class CommGroup:
|
||||
for rank in registered_memories:
|
||||
memory_ids[rank] = proxy_service.add_memory(registered_memories[rank])
|
||||
return memory_ids
|
||||
|
||||
def register_local_memory(self, tensor: cp.ndarray, connections: dict[int, Connection]) -> RegisteredMemory:
|
||||
transport_flags = TransportFlags()
|
||||
for rank in connections:
|
||||
transport_flags |= connections[rank].transport()
|
||||
data_ptr = (
|
||||
tensor.data.ptr
|
||||
if isinstance(tensor, cp.ndarray)
|
||||
else tensor.data_ptr() if is_torch_tensor(tensor) else tensor.ctypes.data
|
||||
)
|
||||
tensor_size = (
|
||||
tensor.numel() * tensor.element_size() if is_torch_tensor(tensor) else tensor.size * tensor.itemsize
|
||||
)
|
||||
return self.communicator.register_memory(data_ptr, tensor_size, transport_flags)
|
||||
|
||||
@@ -126,7 +126,7 @@ void register_core(nb::module_& m) {
|
||||
|
||||
nb::class_<RegisteredMemory>(m, "RegisteredMemory")
|
||||
.def(nb::init<>())
|
||||
.def("data", &RegisteredMemory::data)
|
||||
.def("data", [](RegisteredMemory& self) { return reinterpret_cast<uintptr_t>(self.data()); })
|
||||
.def("size", &RegisteredMemory::size)
|
||||
.def("transports", &RegisteredMemory::transports)
|
||||
.def("serialize", &RegisteredMemory::serialize)
|
||||
|
||||
@@ -28,13 +28,11 @@ void register_memory_channel(nb::module_& m) {
|
||||
.def(nb::init<>())
|
||||
.def("__init__",
|
||||
[](MemoryChannel* memoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
|
||||
RegisteredMemory dst,
|
||||
uintptr_t src) { new (memoryChannel) MemoryChannel(semaphore, dst, reinterpret_cast<void*>(src)); })
|
||||
RegisteredMemory dst, RegisteredMemory src) { new (memoryChannel) MemoryChannel(semaphore, dst, src); })
|
||||
.def("__init__",
|
||||
[](MemoryChannel* memoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
|
||||
RegisteredMemory dst, uintptr_t src, uintptr_t packet_buffer) {
|
||||
new (memoryChannel)
|
||||
MemoryChannel(semaphore, dst, reinterpret_cast<void*>(src), reinterpret_cast<void*>(packet_buffer));
|
||||
RegisteredMemory dst, RegisteredMemory src, uintptr_t packet_buffer) {
|
||||
new (memoryChannel) MemoryChannel(semaphore, dst, src, reinterpret_cast<void*>(packet_buffer));
|
||||
})
|
||||
.def("device_handle", &MemoryChannel::deviceHandle);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user