Mirror of https://github.com/microsoft/mscclpp.git, synced 2026-05-12 09:17:06 +00:00.
Also fixes bugs in MscclppAllReduce6 Below is the performance when the algorithm is fixed to MscclppAllReduce6 on 8 H100 GPUs connected with NVLink using CUDA 12.2. Float16: +-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+ | Size (fp16) | Time (us) | AlgBW (GB/s) | Correctness | NCCL Time (us) | NCCL AlgBW (GB/s) | NCCL Correctness | Speed Up | +-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+ | 2.0 KiB | 11.15 | 0.18 | PASS | 13.82 | 0.15 | PASS | 1.24 | | 4.0 KiB | 11.15 | 0.37 | PASS | 14.74 | 0.28 | PASS | 1.32 | | 8.0 KiB | 11.14 | 0.74 | PASS | 15.17 | 0.54 | PASS | 1.36 | | 16.0 KiB | 11.16 | 1.47 | PASS | 15.77 | 1.04 | PASS | 1.41 | | 32.0 KiB | 11.15 | 2.94 | PASS | 17.50 | 1.87 | PASS | 1.57 | | 64.0 KiB | 11.18 | 5.86 | PASS | 17.64 | 3.71 | PASS | 1.58 | | 128.0 KiB | 11.16 | 11.74 | PASS | 17.83 | 7.35 | PASS | 1.60 | | 256.0 KiB | 11.21 | 23.38 | PASS | 18.00 | 14.57 | PASS | 1.60 | | 512.0 KiB | 11.70 | 44.81 | PASS | 18.42 | 28.46 | PASS | 1.57 | | 1.0 MiB | 13.64 | 76.87 | PASS | 20.23 | 51.83 | PASS | 1.48 | | 2.0 MiB | 17.29 | 121.27 | PASS | 31.60 | 66.36 | PASS | 1.83 | | 4.0 MiB | 25.26 | 166.02 | PASS | 38.74 | 108.26 | PASS | 1.53 | | 8.0 MiB | 40.17 | 208.83 | PASS | 62.86 | 133.45 | PASS | 1.56 | | 16.0 MiB | 70.92 | 236.56 | PASS | 113.36 | 147.99 | PASS | 1.60 | | 32.0 MiB | 131.38 | 255.41 | PASS | 203.21 | 165.13 | PASS | 1.55 | | 64.0 MiB | 253.39 | 264.84 | PASS | 342.12 | 196.15 | PASS | 1.35 | | 128.0 MiB | 496.74 | 270.20 | PASS | 670.62 | 200.14 | PASS | 1.35 | | 256.0 MiB | 982.42 | 273.24 | PASS | 1318.36 | 203.61 | PASS | 1.34 | +-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+ Float32: 
+-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+ | Size (fp32) | Time (us) | AlgBW (GB/s) | Correctness | NCCL Time (us) | NCCL AlgBW (GB/s) | NCCL Correctness | Speed Up | +-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+ | 4.0 KiB | 11.04 | 0.37 | PASS | 14.79 | 0.28 | PASS | 1.34 | | 8.0 KiB | 11.15 | 0.73 | PASS | 15.25 | 0.54 | PASS | 1.37 | | 16.0 KiB | 11.12 | 1.47 | PASS | 15.87 | 1.03 | PASS | 1.43 | | 32.0 KiB | 11.13 | 2.95 | PASS | 17.21 | 1.90 | PASS | 1.55 | | 64.0 KiB | 11.11 | 5.90 | PASS | 17.37 | 3.77 | PASS | 1.56 | | 128.0 KiB | 11.08 | 11.83 | PASS | 17.54 | 7.47 | PASS | 1.58 | | 256.0 KiB | 11.15 | 23.50 | PASS | 17.71 | 14.80 | PASS | 1.59 | | 512.0 KiB | 11.56 | 45.34 | PASS | 18.21 | 28.79 | PASS | 1.57 | | 1.0 MiB | 13.64 | 76.90 | PASS | 19.87 | 52.77 | PASS | 1.46 | | 2.0 MiB | 17.24 | 121.67 | PASS | 31.63 | 66.30 | PASS | 1.84 | | 4.0 MiB | 25.19 | 166.47 | PASS | 38.63 | 108.57 | PASS | 1.53 | | 8.0 MiB | 40.38 | 207.72 | PASS | 62.65 | 133.89 | PASS | 1.55 | | 16.0 MiB | 70.72 | 237.23 | PASS | 114.57 | 146.44 | PASS | 1.62 | | 32.0 MiB | 131.49 | 255.18 | PASS | 200.79 | 167.11 | PASS | 1.53 | | 64.0 MiB | 253.98 | 264.23 | PASS | 342.58 | 195.89 | PASS | 1.35 | | 128.0 MiB | 496.96 | 270.08 | PASS | 670.64 | 200.13 | PASS | 1.35 | | 256.0 MiB | 982.83 | 273.12 | PASS | 1318.90 | 203.53 | PASS | 1.34 | | 512.0 MiB | 1954.07 | 274.75 | PASS | 2609.04 | 205.77 | PASS | 1.34 | +-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+
517 lines · 20 KiB · Python
import os
|
|
import cupy as cp
|
|
import ctypes
|
|
from mscclpp import Transport, ProxyService, SmDevice2DeviceSemaphore
|
|
import mscclpp.comm as mscclpp_comm
|
|
from mscclpp.utils import KernelBuilder, pack
|
|
|
|
|
|
# One InfiniBand transport per local GPU index; rank i on a node uses IB{i}.
IB_TRANSPORTS = [getattr(Transport, f"IB{i}") for i in range(8)]
|
|
|
|
|
def type_to_str(dtype):
    """Map a CuPy/NumPy dtype to the CUDA type name used as the TYPE macro.

    Raises RuntimeError for any dtype other than float16/float32/int32.
    """
    # Equality (not hashing) is required here: both dtype classes
    # (cp.float16) and dtype instances (array.dtype) must match.
    for candidate, cuda_name in (
        (cp.float16, "__half"),
        (cp.float32, "float"),
        (cp.int32, "int"),
    ):
        if dtype == candidate:
            return cuda_name
    raise RuntimeError("Unknown data type")
|
|
|
|
|
|
class MscclppAllReduce1:
    """In-place allreduce over CUDA-IPC peers using SM (thread-block) channels.

    The reduced result is written back into ``memory``. ``read_only`` is
    forwarded to the kernel and selects between its two peer-access variants.
    """

    def __init__(
        self,
        group: mscclpp_comm.CommGroup,
        memory: cp.ndarray,
        read_only: int = 1,
        block_size: int = 1024,
        nblocks: int = 24,
    ):
        self.group = group
        self.memory = memory
        # Every rank except our own is a remote peer.
        peers = [r for r in range(self.group.nranks) if r != self.group.my_rank]

        self.group.barrier()
        # One CUDA-IPC connection per remote peer.
        self.connections = self.group.make_connection(peers, Transport.CudaIpc)
        type_str = type_to_str(memory.dtype)

        # One SM channel per remote peer, backed directly by `memory`.
        self.sm_channels = self.group.make_sm_channels(self.memory, self.connections)
        file_dir = os.path.dirname(os.path.abspath(__file__))
        self.kernel = KernelBuilder(
            file="allreduce.cu",
            kernel_name="allreduce1",
            file_dir=file_dir,
            macro_dict={"TYPE": type_str},
        ).get_compiled_kernel()
        # Raw device handles in peer-rank order (own rank skipped).
        self.device_handles = [
            self.sm_channels[rank].device_handle().raw
            for rank in range(self.group.nranks)
            if rank != self.group.my_rank
        ]

        self.device_handles_cp = cp.asarray(memoryview(b"".join(self.device_handles)), dtype=cp.uint8)

        self.set_params(nblocks, block_size, read_only)

    def __call__(self, stream):
        """Launch the allreduce kernel on ``stream``; returns the reduced buffer."""
        self.kernel.launch_kernel(self.params, self.nblocks, self.block_size, 0, stream)
        return self.memory

    def set_params(self, nblocks, block_size, read_only):
        """Record the launch configuration and re-pack the kernel argument blob."""
        self.nblocks = nblocks
        self.block_size = block_size
        self.read_only = read_only
        self.params = pack(
            self.device_handles_cp,
            self.memory,
            self.group.my_rank,
            self.group.nranks,
            ctypes.c_size_t(self.memory.size),
            self.read_only,
        )

    def auto_tune(self):
        """Apply and yield every (nblocks, block_size, read_only) combination."""
        for nblocks in [8, 12, 16, 24, 32, 48, 64, 72, 96, 108]:
            for block_size in [256, 512, 1024]:
                for read_only in [0, 1]:
                    self.set_params(nblocks, block_size, read_only)
                    yield nblocks, block_size, read_only
|
|
|
|
|
|
class MscclppAllReduce2:
    """Out-of-place allreduce over CUDA-IPC peers using SM channels with scratch.

    Reads from ``memory`` and writes the reduced result into ``memory_out``;
    ``scratch`` (8x the input size) stages the data exchanged with peers.
    """

    def __init__(
        self,
        group: mscclpp_comm.CommGroup,
        memory: cp.ndarray,
        memory_out: cp.ndarray,
        block_size: int = 512,
        nblocks: int = 21,
    ):
        self.group = group
        self.memory = memory
        self.memory_out = memory_out
        # Every rank except our own is a remote peer.
        peers = [r for r in range(self.group.nranks) if r != self.group.my_rank]

        self.group.barrier()
        # One CUDA-IPC connection per remote peer.
        self.connections = self.group.make_connection(peers, Transport.CudaIpc)
        type_str = type_to_str(memory.dtype)

        # Staging area sized 8x the input buffer.
        self.scratch = cp.zeros(self.memory.size * 8, dtype=self.memory.dtype)
        # One SM channel per remote peer with the scratch buffer attached.
        self.sm_channels = self.group.make_sm_channels_with_scratch(self.memory, self.scratch, self.connections)
        file_dir = os.path.dirname(os.path.abspath(__file__))
        self.kernel = KernelBuilder(
            file="allreduce.cu", kernel_name="allreduce2", file_dir=file_dir, macro_dict={"TYPE": type_str}
        ).get_compiled_kernel()
        # Raw device handles in peer-rank order (own rank skipped).
        self.device_handles = [
            self.sm_channels[rank].device_handle().raw
            for rank in range(self.group.nranks)
            if rank != self.group.my_rank
        ]

        self.device_handles_cp = cp.asarray(memoryview(b"".join(self.device_handles)), dtype=cp.uint8)

        self.set_params(nblocks, block_size)

    def __call__(self, stream):
        """Launch the allreduce kernel on ``stream``; returns the output buffer."""
        self.kernel.launch_kernel(self.params, self.nblocks, self.block_size, 0, stream)
        return self.memory_out

    def set_params(self, nblocks, block_size):
        """Record the launch configuration and re-pack the kernel argument blob."""
        self.nblocks = nblocks
        self.block_size = block_size

        self.params = pack(
            self.device_handles_cp,
            self.memory,
            self.scratch,
            self.memory_out,
            self.group.my_rank,
            self.group.nranks,
            ctypes.c_size_t(self.memory.size),
        )

    def auto_tune(self):
        """Apply and yield every (nblocks, block_size) combination."""
        for nblocks in [21, 42, 63, 84, 105]:
            for block_size in [256, 512, 1024]:
                self.set_params(nblocks, block_size)
                yield nblocks, block_size
|
|
|
|
|
|
class MscclppAllReduce3:
    """Two-round in-place allreduce over CUDA-IPC peers using proxy (DMA) channels.

    Round one reduce-scatters peer data into ``scratch``; round two
    all-gathers the reduced shards back into ``memory``. Both channel sets
    are serviced by the supplied ``proxy_service``.
    """

    def __init__(
        self,
        group: mscclpp_comm.CommGroup,
        memory: cp.ndarray,
        proxy_service: ProxyService,
        block_size: int = 1024,
        nblocks: int = 24,
    ):
        self.group = group
        self.memory = memory
        remote_nghrs = list(range(self.group.nranks))
        remote_nghrs.remove(self.group.my_rank)

        self.group.barrier()
        # create a connection for each remote neighbor
        self.connections = self.group.make_connection(remote_nghrs, Transport.CudaIpc)
        type_str = type_to_str(memory.dtype)

        self.proxy_service = proxy_service
        self.scratch = cp.zeros(self.memory.size, dtype=self.memory.dtype)

        # First round targets the scratch buffer; second round targets memory.
        self.fst_round_proxy_chans = self.group.make_proxy_channels_with_scratch(
            self.proxy_service, self.memory, self.scratch, self.connections
        )
        self.snd_round_proxy_chans = self.group.make_proxy_channels(self.proxy_service, self.memory, self.connections)
        file_dir = os.path.dirname(os.path.abspath(__file__))
        self.kernel = KernelBuilder(
            file="allreduce.cu", kernel_name="allreduce3", file_dir=file_dir, macro_dict={"TYPE": type_str}
        ).get_compiled_kernel()
        # Raw device handles for both rounds, in peer-rank order.
        self.fst_device_handles = []
        self.snd_device_handles = []
        for rank in range(self.group.nranks):
            if rank != self.group.my_rank:
                self.fst_device_handles.append(self.fst_round_proxy_chans[rank].device_handle().raw)
                self.snd_device_handles.append(self.snd_round_proxy_chans[rank].device_handle().raw)
        self.fst_device_handles_cp = cp.asarray(memoryview(b"".join(self.fst_device_handles)), dtype=cp.uint8)
        self.snd_device_handles_cp = cp.asarray(memoryview(b"".join(self.snd_device_handles)), dtype=cp.uint8)

        self.set_params(nblocks, block_size)

    def __call__(self, stream):
        """Launch the allreduce kernel on ``stream``; returns the reduced buffer.

        Bug fix: this previously launched with a hard-coded 24x1024 grid,
        silently ignoring the configuration set via set_params()/auto_tune().
        It now honors self.nblocks/self.block_size like the other variants.
        """
        self.kernel.launch_kernel(self.params, self.nblocks, self.block_size, 0, stream)
        return self.memory

    def set_params(self, nblocks, block_size):
        """Record the launch configuration and re-pack the kernel argument blob."""
        self.nblocks = nblocks
        self.block_size = block_size
        self.params = b""
        self.params += pack(
            self.fst_device_handles_cp,
            self.snd_device_handles_cp,
            self.memory,
            self.scratch,
            self.group.my_rank,
            self.group.nranks,
            ctypes.c_size_t(self.memory.size),
        )

    def auto_tune(self):
        """Apply and yield every (nblocks, block_size) combination."""
        nblocks_to_try = [8, 12, 16, 24, 32, 48, 64, 72, 96, 108]
        block_size_to_try = [256, 512, 1024]
        for nblocks in nblocks_to_try:
            for block_size in block_size_to_try:
                self.set_params(nblocks, block_size)
                yield nblocks, block_size
|
|
|
|
|
|
class MscclppAllReduce4:
    """Hierarchical in-place allreduce for multi-node runs.

    Intra-node peers are reached over CUDA IPC (SM channels); cross-node
    peers over InfiniBand proxy channels, using a pipelined
    reduce-scatter + all-gather schedule with depth ``pipeline_depth``.
    """

    def __init__(
        self,
        group: mscclpp_comm.CommGroup,
        memory: cp.ndarray,
        nranks_per_node: int,
        proxy_service: ProxyService,
        nblocks: int = 45,
        block_size: int = 512,
        pipeline_depth: int = 3,
    ):
        self.group = group
        self.memory = memory

        self.nranks_per_node = nranks_per_node
        my_node = self.group.my_rank // nranks_per_node

        def in_same_node(rank):
            return rank // nranks_per_node == my_node

        peers = [r for r in range(self.group.nranks) if r != self.group.my_rank]
        # CUDA IPC within the node, one IB NIC per local rank across nodes.
        transports = {
            rank: Transport.CudaIpc if in_same_node(rank) else IB_TRANSPORTS[rank % nranks_per_node]
            for rank in peers
        }

        self.group.barrier()
        # create a connection for each remote neighbor
        self.connections = self.group.make_connection(peers, transports)
        type_str = type_to_str(memory.dtype)

        self.proxy_service = proxy_service
        self.scratch = cp.zeros(self.memory.size, dtype=self.memory.dtype)
        same_node_connections = {rank: conn for rank, conn in self.connections.items() if in_same_node(rank)}
        # SM channels only for intra-node peers; proxy channels for all peers.
        self.sm_channels = self.group.make_sm_channels(self.memory, same_node_connections)
        self.reduce_scatter_proxy_channels = self.group.make_proxy_channels_with_scratch(
            self.proxy_service, self.memory, self.scratch, self.connections
        )
        self.all_gather_proxy_channels = self.group.make_proxy_channels(
            self.proxy_service, self.memory, self.connections
        )
        file_dir = os.path.dirname(os.path.abspath(__file__))
        self.kernel = KernelBuilder(
            file="allreduce.cu", kernel_name="allreduce4", file_dir=file_dir, macro_dict={"TYPE": type_str}
        ).get_compiled_kernel()
        self.sm_device_handles = []
        # NOTE(review): the "sactter" misspelling is kept — it is an externally
        # visible attribute name; renaming it could break outside accessors.
        self.reduce_sactter_proxy_device_handles = []
        self.all_gather_proxy_device_handles = []
        for rank in range(self.group.nranks):
            if rank == self.group.my_rank:
                continue
            if in_same_node(rank):
                self.sm_device_handles.append(self.sm_channels[rank].device_handle().raw)
            self.reduce_sactter_proxy_device_handles.append(
                self.reduce_scatter_proxy_channels[rank].device_handle().raw
            )
            self.all_gather_proxy_device_handles.append(self.all_gather_proxy_channels[rank].device_handle().raw)

        self.sm_device_handles_cp = cp.asarray(memoryview(b"".join(self.sm_device_handles)), dtype=cp.uint8)
        self.reduce_sactter_proxy_device_handles_cp = cp.asarray(
            memoryview(b"".join(self.reduce_sactter_proxy_device_handles)), dtype=cp.uint8
        )
        self.all_gather_proxy_device_handles_cp = cp.asarray(
            memoryview(b"".join(self.all_gather_proxy_device_handles)), dtype=cp.uint8
        )

        self.set_params(nblocks, block_size, pipeline_depth)

    def __call__(self, stream):
        """Launch the allreduce kernel on ``stream``; returns the reduced buffer."""
        self.kernel.launch_kernel(self.params, self.nblocks, self.block_size, 0, stream)
        return self.memory

    def set_params(self, nblocks, block_size, pipeline_depth):
        """Record the launch configuration and re-pack the kernel argument blob."""
        self.nblocks = nblocks
        self.block_size = block_size
        self.pipeline_depth = pipeline_depth

        self.params = pack(
            self.sm_device_handles_cp,
            self.reduce_sactter_proxy_device_handles_cp,
            self.all_gather_proxy_device_handles_cp,
            self.memory,
            self.scratch,
            self.group.my_rank,
            self.nranks_per_node,
            self.group.nranks,
            bytes(4),  # padding for memory alignment
            ctypes.c_size_t(self.memory.size),
            self.pipeline_depth,
        )

    def auto_tune(self):
        """Apply and yield every (nblocks, block_size, pipeline_depth) combination."""
        for nblocks in [24, 32, 40, 45, 48, 64, 72, 90, 96, 108]:
            for block_size in [256, 512, 1024]:
                for pipeline_depth in [1, 2, 3, 4]:
                    self.set_params(nblocks, block_size, pipeline_depth)
                    yield nblocks, block_size, pipeline_depth
|
|
|
|
|
|
class MscclppAllReduce5:
    """Hierarchical out-of-place allreduce for multi-node runs.

    Intra-node traffic goes over SM channels into the shared ``scratch``
    buffer; cross-node traffic is staged through ``put_buff`` over IB proxy
    channels. The reduced result is written to ``memory_out``.
    """

    def __init__(
        self,
        group: mscclpp_comm.CommGroup,
        memory: cp.ndarray,
        memory_out: cp.ndarray,
        nranks_per_node: int,
        proxy_service: ProxyService,
        nblocks: int = 21,
        block_size: int = 512,
    ):
        self.group = group
        self.memory = memory
        self.memory_out = memory_out

        self.nranks_per_node = nranks_per_node
        my_node = self.group.my_rank // nranks_per_node

        def in_same_node(rank):
            return rank // nranks_per_node == my_node

        peers = [r for r in range(self.group.nranks) if r != self.group.my_rank]
        # CUDA IPC within the node, one IB NIC per local rank across nodes.
        transports = {
            rank: Transport.CudaIpc if in_same_node(rank) else IB_TRANSPORTS[rank % nranks_per_node]
            for rank in peers
        }

        self.group.barrier()
        # create a connection for each remote neighbor
        self.connections = self.group.make_connection(peers, transports)
        type_str = type_to_str(memory.dtype)

        self.proxy_service = proxy_service
        # Intra-node staging (8x input) and cross-node send buffer.
        self.scratch = cp.zeros(self.memory.size * 8, dtype=self.memory.dtype)
        self.put_buff = cp.zeros(self.memory.size * 8 // nranks_per_node, dtype=self.memory.dtype)
        same_node_connections = {rank: conn for rank, conn in self.connections.items() if in_same_node(rank)}
        across_node_connections = {rank: conn for rank, conn in self.connections.items() if not in_same_node(rank)}
        # SM channels for intra-node peers, proxy channels for cross-node peers.
        self.sm_channels = self.group.make_sm_channels_with_scratch(self.memory, self.scratch, same_node_connections)
        self.proxy_channels = self.group.make_proxy_channels_with_scratch(
            self.proxy_service, self.put_buff, self.scratch, across_node_connections
        )
        file_dir = os.path.dirname(os.path.abspath(__file__))
        self.kernel = KernelBuilder(
            file="allreduce.cu", kernel_name="allreduce5", file_dir=file_dir, macro_dict={"TYPE": type_str}
        ).get_compiled_kernel()
        self.sm_device_handles = []
        self.proxy_device_handles = []
        for rank in range(self.group.nranks):
            if rank == self.group.my_rank:
                continue
            if in_same_node(rank):
                self.sm_device_handles.append(self.sm_channels[rank].device_handle().raw)
            else:
                self.proxy_device_handles.append(self.proxy_channels[rank].device_handle().raw)

        self.sm_device_handles_cp = cp.asarray(memoryview(b"".join(self.sm_device_handles)), dtype=cp.uint8)
        self.proxy_device_handles_cp = cp.asarray(memoryview(b"".join(self.proxy_device_handles)), dtype=cp.uint8)

        self.set_params(nblocks, block_size)

    def __call__(self, stream):
        """Launch the allreduce kernel on ``stream``; returns the output buffer."""
        self.kernel.launch_kernel(self.params, self.nblocks, self.block_size, 0, stream)
        return self.memory_out

    def set_params(self, nblocks, block_size):
        """Record the launch configuration and re-pack the kernel argument blob."""
        self.nblocks = nblocks
        self.block_size = block_size

        self.params = pack(
            self.sm_device_handles_cp,
            self.proxy_device_handles_cp,
            self.memory,
            self.scratch,
            self.put_buff,
            self.memory_out,
            self.group.my_rank,
            self.nranks_per_node,
            self.group.nranks,
            bytes(4),  # padding for memory alignment
            ctypes.c_size_t(self.memory.size),
        )

    def auto_tune(self):
        """Apply and yield every (nblocks, block_size) combination."""
        for nblocks in [21, 42, 84]:
            for block_size in [256, 512, 1024]:
                self.set_params(nblocks, block_size)
                yield nblocks, block_size
|
|
|
|
|
|
class MscclppAllReduce6:
    """In-place allreduce using NVLS (multicast) memory and device semaphores.

    Unlike the other variants, this class allocates the data buffer itself
    from NVLS multicast memory — use get_memory() to obtain it and fill it
    before calling. Only float16 and float32 are supported.
    """

    def __init__(
        self,
        group: mscclpp_comm.CommGroup,
        nelem: int,
        memory_dtype: cp.dtype,
        block_size: int = 1024,
        nblocks: int = 32,
    ):
        # Bug fix: validate the dtype up front, before any connections or
        # NVLS memory are created. Previously the check ran only after
        # allocation, and the `else: vector_size = 1` branch was unreachable
        # dead code (the raise guarantees fp16/fp32 from here on).
        if memory_dtype != cp.float16 and memory_dtype != cp.float32:
            raise RuntimeError("Unsupported data type")

        self.group = group
        datatype_size = memory_dtype().itemsize
        buffer_size = nelem * datatype_size
        type_str = type_to_str(memory_dtype)
        all_ranks = list(range(group.nranks))
        remote_nghrs = all_ranks.copy()
        remote_nghrs.remove(self.group.my_rank)

        self.group.barrier()
        # P2P connections carry the semaphore signaling; the NVLS connection
        # provides the multicast memory.
        self.nvlink_connections = self.group.make_connection(remote_nghrs, Transport.CudaIpc)
        self.nvls_connection = group.make_connection(all_ranks, Transport.Nvls)
        # Round the buffer size up to the multicast minimum granularity.
        min_gran = self.nvls_connection.get_multicast_min_granularity()
        aligned_buffer_size = int(((buffer_size + min_gran - 1) // min_gran) * min_gran)
        self.nvls_mem_handle = self.nvls_connection.allocate_bind_memory(
            aligned_buffer_size
        )  # just using recommended size for now
        self.memory_ptr = self.nvls_mem_handle.get_device_ptr()

        # Wrap the raw NVLS pointer as a CuPy ndarray (memory not owned by CuPy).
        self.cp_memory_ptr = cp.cuda.MemoryPointer(cp.cuda.UnownedMemory(self.memory_ptr, aligned_buffer_size, None), 0)
        self.memory = cp.ndarray(nelem, memory_dtype, self.cp_memory_ptr)

        # One device-to-device semaphore per remote peer.
        self.semaphores = group.make_semaphore(self.nvlink_connections, SmDevice2DeviceSemaphore)
        file_dir = os.path.dirname(os.path.abspath(__file__))
        self.kernel = KernelBuilder(
            file="allreduce.cu",
            kernel_name="allreduce6",
            file_dir=file_dir,
            macro_dict={"TYPE": type_str},
        ).get_compiled_kernel()
        # Raw semaphore handles in peer-rank order (own rank skipped).
        self.device_handles = [
            self.semaphores[rank].device_handle().raw
            for rank in range(self.group.nranks)
            if rank != self.group.my_rank
        ]

        self.device_handles_cp = cp.asarray(memoryview(b"".join(self.device_handles)), dtype=cp.uint8)
        self.nvls_handle = self.nvls_mem_handle.device_handle().raw

        # Widest vector load for the element type: 8x fp16 / 4x fp32 = 16 bytes.
        vector_size = 8 if self.memory.dtype == cp.float16 else 4
        self.set_params(nblocks, block_size, vector_size)

    def get_memory(self):
        """Return the NVLS-backed data buffer; fill it before invoking the op."""
        return self.memory

    def __call__(self, stream_ptr):
        """Launch the allreduce kernel on ``stream_ptr``; returns the reduced buffer."""
        self.kernel.launch_kernel(self.params, self.nblocks, self.block_size, 0, stream_ptr)
        return self.memory

    def set_params(self, nblocks, block_size, vector_size):
        """Record the launch configuration and re-pack the kernel argument blob."""
        self.nblocks = nblocks
        self.block_size = block_size
        self.vector_size = vector_size
        self.params = pack(
            self.device_handles_cp,
            self.nvls_handle,
            self.group.my_rank,
            self.group.nranks,
            ctypes.c_size_t(self.memory.size),
            self.vector_size,
        )

    def auto_tune(self):
        """Apply and yield every (nblocks, block_size, vector_size) combination."""
        nblocks_to_try = [8, 12, 16, 24, 32, 48, 64, 72, 96, 108]
        block_size_to_try = [256, 512, 1024]
        # The dtype was validated in __init__, so only fp16/fp32 can occur here.
        if self.memory.dtype == cp.float16:
            vector_size_to_try = [8, 4, 2]
        else:
            vector_size_to_try = [4, 2, 1]
        for nblocks in nblocks_to_try:
            for block_size in block_size_to_try:
                for vector_size in vector_size_to_try:
                    self.set_params(nblocks, block_size, vector_size)
                    yield nblocks, block_size, vector_size
|