Files
mscclpp/python/mscclpp_benchmark/nccl_op.py
Changho Hwang 7bd66a938c Robust correctness test (#221)
Co-authored-by: Aashaka Shah <aashaka96@gmail.com>
2023-11-22 12:06:50 +08:00

24 lines
839 B
Python

import cupy.cuda.nccl as nccl
from mpi4py import MPI
import cupy as cp
class NcclAllReduce:
def __init__(self, nccl_comm: nccl.NcclCommunicator, memory: cp.ndarray):
self.nccl_comm = nccl_comm
self.memory = memory
if memory.dtype == cp.float32:
self.nccl_dtype = nccl.NCCL_FLOAT32
elif memory.dtype == cp.float16:
self.nccl_dtype = nccl.NCCL_FLOAT16
elif memory.dtype == cp.int32:
self.nccl_dtype = nccl.NCCL_INT32
else:
raise RuntimeError("Make sure that the data type is mapped to the correct NCCL data type")
def __call__(self, stream_ptr):
self.nccl_comm.allReduce(
self.memory.data.ptr, self.memory.data.ptr, self.memory.size, self.nccl_dtype, nccl.NCCL_SUM, stream_ptr
)
return self.memory