mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-04 05:31:27 +00:00
NVLS support for msccl++ executor (#375)
- Support mote datatype for multicast operation - Add new OP MULTI_LOAD_REDUCE_STORE to support NVLS - Modify allocSharedPhysicalCuda, which return std::shared_ptr<T> instead of std::shared_ptr<PhysicalCudaMemory> - Add Python support for allocSharedPhysicalCuda Test passed for `allreduce_nvls.json`
This commit is contained in:
@@ -8,6 +8,8 @@ from mscclpp import (
|
||||
ExecutionPlan,
|
||||
PacketType,
|
||||
npkit,
|
||||
alloc_shared_physical_cuda,
|
||||
is_nvls_supported,
|
||||
)
|
||||
import mscclpp.comm as mscclpp_comm
|
||||
from mscclpp.utils import KernelBuilder, pack
|
||||
@@ -125,6 +127,18 @@ def dtype_to_mscclpp_dtype(dtype):
|
||||
raise ValueError(f"Unknown data type: {dtype}")
|
||||
|
||||
|
||||
def allocate_buffer(nelems, dtype):
|
||||
if is_nvls_supported:
|
||||
buffer_raw = alloc_shared_physical_cuda(nelems * cp.dtype(dtype).itemsize)
|
||||
buffer_ptr = cp.cuda.MemoryPointer(
|
||||
cp.cuda.UnownedMemory(buffer_raw.get_ptr(), buffer_raw.size(), buffer_raw), 0
|
||||
)
|
||||
buffer = cp.ndarray(nelems, dtype=dtype, memptr=buffer_ptr)
|
||||
return buffer
|
||||
else:
|
||||
return cp.zeros(nelems, dtype=dtype)
|
||||
|
||||
|
||||
def build_bufs(
|
||||
execution_plan_name: str,
|
||||
size: int,
|
||||
@@ -144,14 +158,14 @@ def build_bufs(
|
||||
nelems_input = nelems
|
||||
nelems_output = nelems
|
||||
|
||||
result_buf = cp.zeros(nelems_output, dtype=dtype)
|
||||
result_buf = allocate_buffer(nelems_output, dtype=dtype)
|
||||
if in_place:
|
||||
if "allgather" in execution_plan_name:
|
||||
input_buf = cp.split(result_buf, num_ranks)[rank]
|
||||
else:
|
||||
input_buf = result_buf
|
||||
else:
|
||||
input_buf = cp.zeros(nelems_input, dtype=dtype)
|
||||
input_buf = allocate_buffer(nelems_input, dtype=dtype)
|
||||
test_buf = cp.zeros(nelems_output, dtype=dtype)
|
||||
|
||||
return input_buf, result_buf, test_buf
|
||||
|
||||
Reference in New Issue
Block a user