Add GpuBuffer class (#423)

* Renamed and moved mem alloc functions into the `mscclpp::detail::`
namespace (now `mscclpp::detail::gpuCalloc*<T>()`)
* Deprecated constructor-calling mem alloc functions
(`mscclpp::makeShared*<T>()` and `mscclpp::makeUnique*<T>()`)
* Added a new `mscclpp::GpuBuffer<T>()` class that should be used in
general for allocating communication buffers
* Added a new `mscclpp.utils.GpuBuffer` Python class that inherits
`cupy.ndarray` and allocates using `mscclpp::gpuMemAlloc`
* Renamed `mscclpp::memcpyCuda*<T>()` functions to
`mscclpp::gpuMemcpy*<T>()` for name consistency
* A few fixes in NVLS memory allocation
* Tackled minor compiler warnings
This commit is contained in:
Changho Hwang
2025-01-07 18:40:01 -08:00
committed by GitHub
parent 6d26b92665
commit 34945fb107
38 changed files with 527 additions and 555 deletions

View File

@@ -6,10 +6,11 @@ import os
import struct
import subprocess
import tempfile
from typing import Any, Type
from typing import Any, Type, Union, Tuple
import cupy as cp
import numpy as np
from ._mscclpp import RawGpuBuffer
try:
import torch
@@ -36,7 +37,7 @@ class Kernel:
nblocks: int,
nthreads: int,
shared: int,
stream: Type[cp.cuda.Stream] or Type[None],
stream: Union[cp.cuda.Stream, None],
):
buffer = (ctypes.c_byte * len(params)).from_buffer_copy(params)
buffer_size = ctypes.c_size_t(len(params))
@@ -137,6 +138,25 @@ class KernelBuilder:
self._tempdir.cleanup()
class GpuBuffer(cp.ndarray):
    """A ``cupy.ndarray`` whose storage is allocated through ``RawGpuBuffer``.

    Intended for GPU communication buffers: the array behaves like any other
    cupy ndarray, but its memory comes from the project's ``RawGpuBuffer``
    allocator and is kept alive by the ndarray via the ``UnownedMemory``
    owner reference.
    """

    def __new__(
        cls,
        shape: Union[int, Tuple[int, ...]],
        dtype: cp.dtype = float,
        strides: Union[Tuple[int, ...], None] = None,
        order: str = "C",
    ):
        """Allocate a new GPU buffer.

        Args:
            shape: Buffer shape; an ``int`` is treated as the 1-D shape ``(shape,)``.
            dtype: Element type, interpreted by numpy/cupy. Defaults to ``float``.
            strides: Optional explicit strides, forwarded to ``cp.ndarray``.
            order: Memory layout ("C" or "F"), forwarded to ``cp.ndarray``.

        Raises:
            ValueError: If ``shape`` is neither tuple-like nor an integer, or
                contains a non-positive dimension.
        """
        # Normalize `shape` to a tuple of positive ints.
        if isinstance(shape, int):
            shape = (shape,)
        try:
            shape = tuple(shape)
        except TypeError:
            # `from None` suppresses the unhelpful chained TypeError traceback.
            raise ValueError("Shape must be a tuple-like or an integer.") from None
        if any(s <= 0 for s in shape):
            raise ValueError("Shape must be positive.")
        # Allocate the raw device memory. Cast the element count to a plain
        # int: np.prod returns a numpy scalar — and a *float* (1.0) for an
        # empty shape — which the pybind-side RawGpuBuffer ctor may reject.
        buffer = RawGpuBuffer(int(np.prod(shape)) * np.dtype(dtype).itemsize)
        # Passing `buffer` as the UnownedMemory owner ties the allocation's
        # lifetime to the ndarray that references it.
        memptr = cp.cuda.MemoryPointer(cp.cuda.UnownedMemory(buffer.data(), buffer.bytes(), buffer), 0)
        # NOTE(review): this returns a plain cp.ndarray, not a `cls` instance,
        # so isinstance(GpuBuffer(...), GpuBuffer) is False — confirm this is
        # intentional before "fixing" it.
        return cp.ndarray(shape, dtype=dtype, strides=strides, order=order, memptr=memptr)
def pack(*args):
res = b""
for arg in list(args):