mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 17:26:04 +00:00
Add GpuBuffer class (#423)
* Renamed and moved mem alloc functions into the `mscclpp::detail::` namespace (now `mscclpp::detail::gpuCalloc*<T>()`) * Deprecated constructor-calling mem alloc functions (`mscclpp::makeShared*<T>()` and `mscclpp::makeUnique*<T>()`) * Added a new `mscclpp::GpuBuffer<T>()` class that should be used in general for allocating communication buffers * Added a new `mscclpp.utils.GpuBuffer` Python class that inherits `cupy.ndarray` and allocates using `mscclpp::gpuMemAlloc` * Renamed `mscclpp::memcpyCuda*<T>()` functions into `mscclpp::gpuMemcpy*<T>()` for name consistency * A few fixes in NVLS memory allocation * Tackled minor compiler warnings
This commit is contained in:
@@ -15,6 +15,7 @@ from mpi4py import MPI
|
||||
import cupy.cuda.nccl as nccl
|
||||
import mscclpp.comm as mscclpp_comm
|
||||
from mscclpp import ProxyService, is_nvls_supported
|
||||
from mscclpp.utils import GpuBuffer
|
||||
from prettytable import PrettyTable
|
||||
import netifaces as ni
|
||||
import ipaddress
|
||||
@@ -162,8 +163,8 @@ def find_best_config(mscclpp_call, niter):
|
||||
def run_benchmark(
|
||||
mscclpp_group: mscclpp_comm.CommGroup, nccl_op: nccl.NcclCommunicator, table: PrettyTable, niter: int, nelem: int
|
||||
):
|
||||
memory = cp.zeros(nelem, dtype=data_type)
|
||||
memory_out = cp.zeros(nelem, dtype=data_type)
|
||||
memory = GpuBuffer(nelem, dtype=data_type)
|
||||
memory_out = GpuBuffer(nelem, dtype=data_type)
|
||||
cp.cuda.runtime.deviceSynchronize()
|
||||
|
||||
proxy_service = ProxyService()
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import os
|
||||
import cupy as cp
|
||||
import ctypes
|
||||
from mscclpp import Transport, ProxyService, SmDevice2DeviceSemaphore, alloc_shared_physical_cuda
|
||||
from mscclpp import Transport, ProxyService, SmDevice2DeviceSemaphore
|
||||
import mscclpp.comm as mscclpp_comm
|
||||
from mscclpp.utils import KernelBuilder, pack
|
||||
from mscclpp.utils import KernelBuilder, GpuBuffer, pack
|
||||
|
||||
|
||||
IB_TRANSPORTS = [
|
||||
@@ -115,7 +115,7 @@ class MscclppAllReduce2:
|
||||
self.connections = self.group.make_connection(remote_nghrs, Transport.CudaIpc)
|
||||
type_str = type_to_str(memory.dtype)
|
||||
|
||||
self.scratch = cp.zeros(self.memory.size * 8, dtype=self.memory.dtype)
|
||||
self.scratch = GpuBuffer(self.memory.size * 8, dtype=self.memory.dtype)
|
||||
# create a sm_channel for each remote neighbor
|
||||
self.sm_channels = self.group.make_sm_channels_with_scratch(self.memory, self.scratch, self.connections)
|
||||
file_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -179,7 +179,7 @@ class MscclppAllReduce3:
|
||||
type_str = type_to_str(memory.dtype)
|
||||
|
||||
self.proxy_service = proxy_service
|
||||
self.scratch = cp.zeros(self.memory.size, dtype=self.memory.dtype)
|
||||
self.scratch = GpuBuffer(self.memory.size, dtype=self.memory.dtype)
|
||||
|
||||
# create a sm_channel for each remote neighbor
|
||||
self.fst_round_proxy_chans = self.group.make_proxy_channels_with_scratch(
|
||||
@@ -259,7 +259,7 @@ class MscclppAllReduce4:
|
||||
type_str = type_to_str(memory.dtype)
|
||||
|
||||
self.proxy_service = proxy_service
|
||||
self.scratch = cp.zeros(self.memory.size, dtype=self.memory.dtype)
|
||||
self.scratch = GpuBuffer(self.memory.size, dtype=self.memory.dtype)
|
||||
same_node_connections = {rank: conn for rank, conn in self.connections.items() if in_same_node(rank)}
|
||||
# create a sm_channel for each remote neighbor
|
||||
self.sm_channels = self.group.make_sm_channels(self.memory, same_node_connections)
|
||||
@@ -362,8 +362,8 @@ class MscclppAllReduce5:
|
||||
type_str = type_to_str(memory.dtype)
|
||||
|
||||
self.proxy_service = proxy_service
|
||||
self.scratch = cp.zeros(self.memory.size * 8, dtype=self.memory.dtype)
|
||||
self.put_buff = cp.zeros(self.memory.size * 8 // nranks_per_node, dtype=self.memory.dtype)
|
||||
self.scratch = GpuBuffer(self.memory.size * 8, dtype=self.memory.dtype)
|
||||
self.put_buff = GpuBuffer(self.memory.size * 8 // nranks_per_node, dtype=self.memory.dtype)
|
||||
same_node_connections = {rank: conn for rank, conn in self.connections.items() if in_same_node(rank)}
|
||||
across_node_connections = {rank: conn for rank, conn in self.connections.items() if not in_same_node(rank)}
|
||||
# create a sm_channel for each remote neighbor
|
||||
@@ -441,18 +441,10 @@ class MscclppAllReduce6:
|
||||
# create a connection for each remote neighbor
|
||||
self.nvlink_connections = self.group.make_connection(remote_nghrs, Transport.CudaIpc)
|
||||
self.nvls_connection = group.make_connection(all_ranks, Transport.Nvls)
|
||||
min_gran = self.nvls_connection.get_multicast_min_granularity()
|
||||
aligned_buffer_size = int(((buffer_size + min_gran - 1) // min_gran) * min_gran)
|
||||
buffer_raw = alloc_shared_physical_cuda(aligned_buffer_size)
|
||||
self.memory = GpuBuffer(nelem, memory_dtype)
|
||||
self.nvls_mem_handle = self.nvls_connection.bind_allocated_memory(
|
||||
buffer_raw.get_ptr(), aligned_buffer_size
|
||||
) # just using recommended size for now
|
||||
self.memory_ptr = self.nvls_mem_handle.get_device_ptr()
|
||||
|
||||
self.cp_memory_ptr = cp.cuda.MemoryPointer(
|
||||
cp.cuda.UnownedMemory(self.memory_ptr, aligned_buffer_size, buffer_raw), 0
|
||||
self.memory.data.ptr, self.memory.data.mem.size
|
||||
)
|
||||
self.memory = cp.ndarray(nelem, memory_dtype, self.cp_memory_ptr)
|
||||
|
||||
# create a sm_channel for each remote neighbor
|
||||
self.semaphores = group.make_semaphore(self.nvlink_connections, SmDevice2DeviceSemaphore)
|
||||
|
||||
Reference in New Issue
Block a user