mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 01:10:22 +00:00
Add GpuBuffer class (#423)
* Renamed and moved mem alloc functions into the `mscclpp::detail::` namespace (now `mscclpp::detail::gpuCalloc*<T>()`)
* Deprecated constructor-calling mem alloc functions (`mscclpp::makeShared*<T>()` and `mscclpp::makeUnique*<T>()`)
* Added a new `mscclpp::GpuBuffer<T>()` class that should be used in general for allocating communication buffers
* Added a new `mscclpp.utils.GpuBuffer` Python class that inherits `cupy.ndarray` and allocates using `mscclpp::gpuMemAlloc`
* Renamed `mscclpp::memcpyCuda*<T>()` functions into `mscclpp::gpuMemcpy*<T>()` for name consistency
* A few fixes in NVLS memory allocation
* Tackled minor compiler warnings
This commit is contained in:
@@ -24,9 +24,9 @@ from ._mscclpp import (
|
||||
Executor,
|
||||
ExecutionPlan,
|
||||
PacketType,
|
||||
RawGpuBuffer,
|
||||
version,
|
||||
is_nvls_supported,
|
||||
alloc_shared_physical_cuda,
|
||||
npkit,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,30 +1,18 @@
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
|
||||
// #include <memory>
|
||||
#include <mscclpp/gpu_data_types.hpp>
|
||||
#include <mscclpp/gpu_utils.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
class PyCudaMemory {
|
||||
public:
|
||||
PyCudaMemory(size_t size) : size_(size) { ptr_ = allocSharedPhysicalCuda<char>(size); }
|
||||
|
||||
uintptr_t getPtr() const { return (uintptr_t)(ptr_.get()); }
|
||||
size_t size() const { return size_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<char> ptr_;
|
||||
size_t size_;
|
||||
};
|
||||
|
||||
/// Registers the GPU memory utilities on the Python module `m`:
/// PyCudaMemory, alloc_shared_physical_cuda, is_nvls_supported, and RawGpuBuffer.
void register_gpu_utils(nb::module_& m) {
  nb::class_<PyCudaMemory>(m, "PyCudaMemory")
      .def(nb::init<size_t>(), nb::arg("size"))
      .def("get_ptr", &PyCudaMemory::getPtr, "Get the raw pointer")
      .def("size", &PyCudaMemory::size, "Get the size of the allocated memory");

  // Factory returns a shared_ptr so Python shares ownership of the allocation.
  m.def(
      "alloc_shared_physical_cuda", [](size_t nbytes) { return std::make_shared<PyCudaMemory>(nbytes); },
      nb::arg("size"));
  m.def("is_nvls_supported", &isNvlsSupported);

  // Byte-typed GpuBuffer; Python wraps it (see mscclpp.utils.GpuBuffer) with dtype/shape info.
  nb::class_<GpuBuffer<char>>(m, "RawGpuBuffer")
      .def(nb::init<size_t>(), nb::arg("nelems"))
      .def("nelems", &GpuBuffer<char>::nelems)
      .def("bytes", &GpuBuffer<char>::bytes)
      .def("data", [](GpuBuffer<char>& buf) { return reinterpret_cast<uintptr_t>(buf.data()); });
}
|
||||
|
||||
@@ -34,5 +34,5 @@ void register_nvls(nb::module_& m) {
|
||||
.def("get_multicast_min_granularity", &NvlsConnection::getMultiCastMinGranularity);
|
||||
|
||||
m.def("connect_nvls_collective", &connectNvlsCollective, nb::arg("communicator"), nb::arg("allRanks"),
|
||||
nb::arg("bufferSize") = NvlsConnection::DefaultNvlsBufferSize);
|
||||
nb::arg("bufferSize"));
|
||||
}
|
||||
|
||||
@@ -6,10 +6,11 @@ import os
|
||||
import struct
|
||||
import subprocess
|
||||
import tempfile
|
||||
from typing import Any, Type
|
||||
from typing import Any, Type, Union, Tuple
|
||||
|
||||
import cupy as cp
|
||||
import numpy as np
|
||||
from ._mscclpp import RawGpuBuffer
|
||||
|
||||
try:
|
||||
import torch
|
||||
@@ -36,7 +37,7 @@ class Kernel:
|
||||
nblocks: int,
|
||||
nthreads: int,
|
||||
shared: int,
|
||||
stream: Type[cp.cuda.Stream] or Type[None],
|
||||
stream: Union[cp.cuda.Stream, None],
|
||||
):
|
||||
buffer = (ctypes.c_byte * len(params)).from_buffer_copy(params)
|
||||
buffer_size = ctypes.c_size_t(len(params))
|
||||
@@ -137,6 +138,25 @@ class KernelBuilder:
|
||||
self._tempdir.cleanup()
|
||||
|
||||
|
||||
class GpuBuffer(cp.ndarray):
    """Constructor for cupy arrays backed by mscclpp's ``RawGpuBuffer`` allocator.

    NOTE(review): ``__new__`` returns a plain ``cp.ndarray`` view over the
    allocation (not a ``GpuBuffer`` instance); the class serves purely as a
    factory. The ``RawGpuBuffer`` object is kept alive by the
    ``UnownedMemory`` owner argument.
    """

    def __new__(
        cls, shape: Union[int, Tuple[int]], dtype: cp.dtype = float, strides: Tuple[int] = None, order: str = "C"
    ):
        # Normalize `shape` into a tuple and validate it.
        if isinstance(shape, int):
            shape = (shape,)
        try:
            shape = tuple(shape)
        except TypeError:
            raise ValueError("Shape must be a tuple-like or an integer.")
        if any(s <= 0 for s in shape):
            raise ValueError("Shape must be positive.")
        # Allocate the backing storage through mscclpp and hand it to cupy.
        nbytes = np.prod(shape) * np.dtype(dtype).itemsize
        raw = RawGpuBuffer(nbytes)
        memptr = cp.cuda.MemoryPointer(cp.cuda.UnownedMemory(raw.data(), raw.bytes(), raw), 0)
        return cp.ndarray(shape, dtype=dtype, strides=strides, order=order, memptr=memptr)
|
||||
|
||||
|
||||
def pack(*args):
|
||||
res = b""
|
||||
for arg in list(args):
|
||||
|
||||
@@ -20,5 +20,4 @@ void register_utils(nb::module_& m) {
|
||||
nb::class_<ScopedTimer, Timer>(m, "ScopedTimer").def(nb::init<std::string>(), nb::arg("name"));
|
||||
|
||||
m.def("get_host_name", &getHostName, nb::arg("maxlen"), nb::arg("delim"));
|
||||
m.def("is_nvls_supported", &isNvlsSupported);
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ from mpi4py import MPI
|
||||
import cupy.cuda.nccl as nccl
|
||||
import mscclpp.comm as mscclpp_comm
|
||||
from mscclpp import ProxyService, is_nvls_supported
|
||||
from mscclpp.utils import GpuBuffer
|
||||
from prettytable import PrettyTable
|
||||
import netifaces as ni
|
||||
import ipaddress
|
||||
@@ -162,8 +163,8 @@ def find_best_config(mscclpp_call, niter):
|
||||
def run_benchmark(
|
||||
mscclpp_group: mscclpp_comm.CommGroup, nccl_op: nccl.NcclCommunicator, table: PrettyTable, niter: int, nelem: int
|
||||
):
|
||||
memory = cp.zeros(nelem, dtype=data_type)
|
||||
memory_out = cp.zeros(nelem, dtype=data_type)
|
||||
memory = GpuBuffer(nelem, dtype=data_type)
|
||||
memory_out = GpuBuffer(nelem, dtype=data_type)
|
||||
cp.cuda.runtime.deviceSynchronize()
|
||||
|
||||
proxy_service = ProxyService()
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import os
|
||||
import cupy as cp
|
||||
import ctypes
|
||||
from mscclpp import Transport, ProxyService, SmDevice2DeviceSemaphore, alloc_shared_physical_cuda
|
||||
from mscclpp import Transport, ProxyService, SmDevice2DeviceSemaphore
|
||||
import mscclpp.comm as mscclpp_comm
|
||||
from mscclpp.utils import KernelBuilder, pack
|
||||
from mscclpp.utils import KernelBuilder, GpuBuffer, pack
|
||||
|
||||
|
||||
IB_TRANSPORTS = [
|
||||
@@ -115,7 +115,7 @@ class MscclppAllReduce2:
|
||||
self.connections = self.group.make_connection(remote_nghrs, Transport.CudaIpc)
|
||||
type_str = type_to_str(memory.dtype)
|
||||
|
||||
self.scratch = cp.zeros(self.memory.size * 8, dtype=self.memory.dtype)
|
||||
self.scratch = GpuBuffer(self.memory.size * 8, dtype=self.memory.dtype)
|
||||
# create a sm_channel for each remote neighbor
|
||||
self.sm_channels = self.group.make_sm_channels_with_scratch(self.memory, self.scratch, self.connections)
|
||||
file_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -179,7 +179,7 @@ class MscclppAllReduce3:
|
||||
type_str = type_to_str(memory.dtype)
|
||||
|
||||
self.proxy_service = proxy_service
|
||||
self.scratch = cp.zeros(self.memory.size, dtype=self.memory.dtype)
|
||||
self.scratch = GpuBuffer(self.memory.size, dtype=self.memory.dtype)
|
||||
|
||||
# create a sm_channel for each remote neighbor
|
||||
self.fst_round_proxy_chans = self.group.make_proxy_channels_with_scratch(
|
||||
@@ -259,7 +259,7 @@ class MscclppAllReduce4:
|
||||
type_str = type_to_str(memory.dtype)
|
||||
|
||||
self.proxy_service = proxy_service
|
||||
self.scratch = cp.zeros(self.memory.size, dtype=self.memory.dtype)
|
||||
self.scratch = GpuBuffer(self.memory.size, dtype=self.memory.dtype)
|
||||
same_node_connections = {rank: conn for rank, conn in self.connections.items() if in_same_node(rank)}
|
||||
# create a sm_channel for each remote neighbor
|
||||
self.sm_channels = self.group.make_sm_channels(self.memory, same_node_connections)
|
||||
@@ -362,8 +362,8 @@ class MscclppAllReduce5:
|
||||
type_str = type_to_str(memory.dtype)
|
||||
|
||||
self.proxy_service = proxy_service
|
||||
self.scratch = cp.zeros(self.memory.size * 8, dtype=self.memory.dtype)
|
||||
self.put_buff = cp.zeros(self.memory.size * 8 // nranks_per_node, dtype=self.memory.dtype)
|
||||
self.scratch = GpuBuffer(self.memory.size * 8, dtype=self.memory.dtype)
|
||||
self.put_buff = GpuBuffer(self.memory.size * 8 // nranks_per_node, dtype=self.memory.dtype)
|
||||
same_node_connections = {rank: conn for rank, conn in self.connections.items() if in_same_node(rank)}
|
||||
across_node_connections = {rank: conn for rank, conn in self.connections.items() if not in_same_node(rank)}
|
||||
# create a sm_channel for each remote neighbor
|
||||
@@ -441,18 +441,10 @@ class MscclppAllReduce6:
|
||||
# create a connection for each remote neighbor
|
||||
self.nvlink_connections = self.group.make_connection(remote_nghrs, Transport.CudaIpc)
|
||||
self.nvls_connection = group.make_connection(all_ranks, Transport.Nvls)
|
||||
min_gran = self.nvls_connection.get_multicast_min_granularity()
|
||||
aligned_buffer_size = int(((buffer_size + min_gran - 1) // min_gran) * min_gran)
|
||||
buffer_raw = alloc_shared_physical_cuda(aligned_buffer_size)
|
||||
self.memory = GpuBuffer(nelem, memory_dtype)
|
||||
self.nvls_mem_handle = self.nvls_connection.bind_allocated_memory(
|
||||
buffer_raw.get_ptr(), aligned_buffer_size
|
||||
) # just using recommended size for now
|
||||
self.memory_ptr = self.nvls_mem_handle.get_device_ptr()
|
||||
|
||||
self.cp_memory_ptr = cp.cuda.MemoryPointer(
|
||||
cp.cuda.UnownedMemory(self.memory_ptr, aligned_buffer_size, buffer_raw), 0
|
||||
self.memory.data.ptr, self.memory.data.mem.size
|
||||
)
|
||||
self.memory = cp.ndarray(nelem, memory_dtype, self.cp_memory_ptr)
|
||||
|
||||
# create a sm_channel for each remote neighbor
|
||||
self.semaphores = group.make_semaphore(self.nvlink_connections, SmDevice2DeviceSemaphore)
|
||||
|
||||
@@ -8,11 +8,9 @@ from mscclpp import (
|
||||
ExecutionPlan,
|
||||
PacketType,
|
||||
npkit,
|
||||
alloc_shared_physical_cuda,
|
||||
is_nvls_supported,
|
||||
)
|
||||
import mscclpp.comm as mscclpp_comm
|
||||
from mscclpp.utils import KernelBuilder, pack
|
||||
from mscclpp.utils import KernelBuilder, GpuBuffer, pack
|
||||
import os
|
||||
import struct
|
||||
|
||||
@@ -129,18 +127,6 @@ def dtype_to_mscclpp_dtype(dtype):
|
||||
raise ValueError(f"Unknown data type: {dtype}")
|
||||
|
||||
|
||||
def allocate_buffer(nelems, dtype):
    """Allocate a GPU buffer of ``nelems`` elements of ``dtype``.

    When NVLS is supported, uses physically-backed memory from
    ``alloc_shared_physical_cuda`` (required for multicast binding); the
    resulting array is NOT zero-initialized. Otherwise falls back to a
    regular zeroed cupy array.
    """
    if not is_nvls_supported():
        return cp.zeros(nelems, dtype=dtype)
    raw = alloc_shared_physical_cuda(nelems * cp.dtype(dtype).itemsize)
    # `raw` is passed as the owner so the allocation outlives the ndarray view.
    memptr = cp.cuda.MemoryPointer(cp.cuda.UnownedMemory(raw.get_ptr(), raw.size(), raw), 0)
    return cp.ndarray(nelems, dtype=dtype, memptr=memptr)
|
||||
|
||||
|
||||
def build_bufs(
|
||||
collective: str,
|
||||
size: int,
|
||||
@@ -160,14 +146,14 @@ def build_bufs(
|
||||
nelems_input = nelems
|
||||
nelems_output = nelems
|
||||
|
||||
result_buf = allocate_buffer(nelems_output, dtype=dtype)
|
||||
result_buf = GpuBuffer(nelems_output, dtype=dtype)
|
||||
if in_place:
|
||||
if "allgather" in collective:
|
||||
input_buf = cp.split(result_buf, num_ranks)[rank]
|
||||
else:
|
||||
input_buf = result_buf
|
||||
else:
|
||||
input_buf = allocate_buffer(nelems_input, dtype=dtype)
|
||||
input_buf = GpuBuffer(nelems_input, dtype=dtype)
|
||||
test_buf = cp.zeros(nelems_output, dtype=dtype)
|
||||
|
||||
return input_buf, result_buf, test_buf
|
||||
|
||||
@@ -27,7 +27,7 @@ from mscclpp import (
|
||||
npkit,
|
||||
)
|
||||
import mscclpp.comm as mscclpp_comm
|
||||
from mscclpp.utils import KernelBuilder, pack
|
||||
from mscclpp.utils import KernelBuilder, GpuBuffer, pack
|
||||
from ._cpp import _ext
|
||||
from .mscclpp_mpi import MpiGroup, parametrize_mpi_groups, mpi_group
|
||||
|
||||
@@ -156,12 +156,26 @@ def test_group_with_connections(mpi_group: MpiGroup, transport: str):
|
||||
create_group_and_connection(mpi_group, transport)
|
||||
|
||||
|
||||
@parametrize_mpi_groups(1)
@pytest.mark.parametrize("nelem", [2**i for i in [0, 10, 15, 20]])
@pytest.mark.parametrize("dtype", [cp.float32, cp.float16])
def test_gpu_buffer(mpi_group: MpiGroup, nelem: int, dtype: cp.dtype):
    """Sanity-check GpuBuffer allocation metadata across sizes and dtypes."""
    memory = GpuBuffer(nelem, dtype=dtype)
    itemsize = cp.dtype(dtype).itemsize
    expected_nbytes = nelem * itemsize
    assert memory.shape == (nelem,)
    assert memory.dtype == dtype
    assert memory.itemsize == itemsize
    assert memory.nbytes == expected_nbytes
    # Underlying allocation must be real (non-null) and large enough.
    assert memory.data.ptr != 0
    assert memory.data.mem.ptr != 0
    assert memory.data.mem.size >= expected_nbytes
|
||||
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
|
||||
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
|
||||
def test_connection_write(mpi_group: MpiGroup, transport: Transport, nelem: int):
|
||||
group, connections = create_group_and_connection(mpi_group, transport)
|
||||
memory = cp.zeros(nelem, dtype=cp.int32)
|
||||
memory = GpuBuffer(nelem, dtype=cp.int32)
|
||||
nelemPerRank = nelem // group.nranks
|
||||
sizePerRank = nelemPerRank * memory.itemsize
|
||||
memory[(nelemPerRank * group.my_rank) : (nelemPerRank * (group.my_rank + 1))] = group.my_rank + 1
|
||||
@@ -436,13 +450,12 @@ def test_d2d_semaphores(mpi_group: MpiGroup):
|
||||
def test_sm_channels(mpi_group: MpiGroup, nelem: int, use_packet: bool):
|
||||
group, connections = create_group_and_connection(mpi_group, "NVLink")
|
||||
|
||||
memory = cp.zeros(nelem, dtype=cp.int32)
|
||||
memory = GpuBuffer(nelem, dtype=cp.int32)
|
||||
if use_packet:
|
||||
scratch = cp.zeros(nelem * 2, dtype=cp.int32)
|
||||
scratch = GpuBuffer(nelem * 2, dtype=cp.int32)
|
||||
else:
|
||||
scratch = None
|
||||
nelemPerRank = nelem // group.nranks
|
||||
nelemPerRank * memory.itemsize
|
||||
memory[(nelemPerRank * group.my_rank) : (nelemPerRank * (group.my_rank + 1))] = group.my_rank + 1
|
||||
memory_expected = cp.zeros_like(memory)
|
||||
for rank in range(group.nranks):
|
||||
@@ -484,7 +497,7 @@ def test_fifo(
|
||||
def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
|
||||
group, connections = create_group_and_connection(mpi_group, transport)
|
||||
|
||||
memory = cp.zeros(nelem, dtype=cp.int32)
|
||||
memory = GpuBuffer(nelem, dtype=cp.int32)
|
||||
nelemPerRank = nelem // group.nranks
|
||||
nelemPerRank * memory.itemsize
|
||||
memory[(nelemPerRank * group.my_rank) : (nelemPerRank * (group.my_rank + 1))] = group.my_rank + 1
|
||||
@@ -534,11 +547,11 @@ def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
|
||||
def test_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packet: bool):
|
||||
group, connections = create_group_and_connection(mpi_group, transport)
|
||||
|
||||
memory = cp.zeros(nelem, dtype=cp.int32)
|
||||
memory = GpuBuffer(nelem, dtype=cp.int32)
|
||||
if use_packet:
|
||||
scratch = cp.zeros(nelem * 2, dtype=cp.int32)
|
||||
scratch = GpuBuffer(nelem * 2, dtype=cp.int32)
|
||||
else:
|
||||
scratch = cp.zeros(1, dtype=cp.int32) # just so that we can pass a valid ptr
|
||||
scratch = GpuBuffer(1, dtype=cp.int32) # just so that we can pass a valid ptr
|
||||
nelemPerRank = nelem // group.nranks
|
||||
nelemPerRank * memory.itemsize
|
||||
memory[(nelemPerRank * group.my_rank) : (nelemPerRank * (group.my_rank + 1))] = group.my_rank + 1
|
||||
|
||||
Reference in New Issue
Block a user