NVLS support. (#250)

Co-authored-by: Saeed Maleki <saemal@microsoft.com>
Co-authored-by: Binyang Li <binyli@microsoft.com>
Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
This commit is contained in:
Saeed Maleki
2024-02-04 20:46:10 -08:00
committed by GitHub
parent 4eb0a08b8c
commit 91d592dcc0
22 changed files with 1172 additions and 56 deletions

View File

@@ -6,6 +6,7 @@ import os as _os
from ._mscclpp import (
Communicator,
Connection,
EndpointConfig,
Fifo,
Host2DeviceSemaphore,
Host2HostSemaphore,
@@ -19,6 +20,7 @@ from ._mscclpp import (
Transport,
TransportFlags,
version,
is_nvls_supported,
)
__version__ = version()

View File

@@ -8,6 +8,7 @@ import cupy as cp
from ._mscclpp import (
Communicator,
Connection,
EndpointConfig,
Host2DeviceSemaphore,
Host2HostSemaphore,
ProxyService,
@@ -79,15 +80,21 @@ class CommGroup:
assert False # only 8 IBs are supported
def make_connection(
    self,
    all_ranks: list[int],
    endpoints: EndpointConfig | Transport | dict[int, EndpointConfig] | dict[int, Transport],
) -> dict[int, Connection]:
    """Connect this rank to ``all_ranks`` using the given endpoint(s).

    Args:
        all_ranks: Ranks to connect to. For an NVLS endpoint this list is
            handed unchanged to ``connct_nvls_collective``.
        endpoints: Either one ``Transport``/``EndpointConfig`` shared by every
            rank, or a dict mapping each rank in ``all_ranks`` to its own
            ``Transport``/``EndpointConfig``.

    Returns:
        For a single NVLS endpoint: whatever
        ``communicator.connct_nvls_collective`` returns (a single collective
        connection object, not a dict). Otherwise a dict mapping each rank to
        its established ``Connection``.
    """
    # A bare Transport is promoted so that ``.transport`` can be inspected.
    if isinstance(endpoints, Transport):
        endpoints = EndpointConfig(endpoints)

    per_rank = isinstance(endpoints, dict)
    # Only a single (non-dict) endpoint has a ``.transport`` attribute; the
    # previous unconditional access raised AttributeError for dict input.
    if not per_rank and endpoints.transport == Transport.Nvls:
        return self.communicator.connct_nvls_collective(all_ranks, endpoints)

    pending = {}
    for rank in all_ranks:
        endpoint = endpoints[rank] if per_rank else endpoints
        pending[rank] = self.communicator.connect_on_setup(rank, 0, endpoint)
    # setup() must run before the pending connections can be resolved.
    self.communicator.setup()
    return {rank: future.get() for rank, future in pending.items()}

View File

@@ -6,6 +6,7 @@
#include <nanobind/stl/array.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>
#include <mscclpp/core.hpp>
@@ -72,6 +73,7 @@ void register_core(nb::module_& m) {
nb::enum_<Transport>(m, "Transport")
.value("Unknown", Transport::Unknown)
.value("CudaIpc", Transport::CudaIpc)
.value("Nvls", Transport::Nvls)
.value("IB0", Transport::IB0)
.value("IB1", Transport::IB1)
.value("IB2", Transport::IB2)
@@ -124,6 +126,24 @@ void register_core(nb::module_& m) {
.def("transport", &Connection::transport)
.def("remote_transport", &Connection::remoteTransport);
nb::class_<NvlsConnection::DeviceMulticastPointer>(m, "DeviceMulticastPointer")
.def("get_device_ptr",
[](NvlsConnection::DeviceMulticastPointer* self) { return (uintptr_t)self->getDevicePtr(); })
.def("device_handle", &NvlsConnection::DeviceMulticastPointer::deviceHandle);
nb::class_<NvlsConnection::DeviceMulticastPointer::DeviceHandle>(m, "DeviceHandle")
.def(nb::init<>())
.def_rw("devicePtr", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::devicePtr)
.def_rw("mcPtr", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::mcPtr)
.def_rw("size", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::bufferSize)
.def_prop_ro("raw", [](const NvlsConnection::DeviceMulticastPointer::DeviceHandle& self) -> nb::bytes {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});
nb::class_<NvlsConnection>(m, "NvlsConnection")
.def("allocate_bind_memory", &NvlsConnection::allocateAndBindCuda)
.def("get_multicast_min_granularity", &NvlsConnection::getMultiCastMinGranularity);
nb::class_<Endpoint>(m, "Endpoint")
.def("transport", &Endpoint::transport)
.def("serialize", &Endpoint::serialize)
@@ -132,6 +152,7 @@ void register_core(nb::module_& m) {
nb::class_<EndpointConfig>(m, "EndpointConfig")
.def(nb::init<>())
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
.def(nb::init<Transport, size_t>(), nb::arg("transport"), nb::arg("nvlsBufferSize"))
.def_rw("transport", &EndpointConfig::transport)
.def_rw("ib_max_cq_size", &EndpointConfig::ibMaxCqSize)
.def_rw("ib_max_cq_poll_num", &EndpointConfig::ibMaxCqPollNum)
@@ -168,6 +189,7 @@ void register_core(nb::module_& m) {
.def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag"))
.def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"),
nb::arg("localConfig"))
.def("connct_nvls_collective", &Communicator::connctNvlsCollective, nb::arg("allRanks"), nb::arg("config"))
.def("remote_rank_of", &Communicator::remoteRankOf)
.def("tag_of", &Communicator::tagOf)
.def("setup", &Communicator::setup);

View File

@@ -20,4 +20,5 @@ void register_utils(nb::module_& m) {
nb::class_<ScopedTimer, Timer>(m, "ScopedTimer").def(nb::init<std::string>(), nb::arg("name"));
m.def("get_host_name", &getHostName, nb::arg("maxlen"), nb::arg("delim"));
m.def("is_nvls_supported", &isNvlsSupported);
}

View File

@@ -4,6 +4,7 @@
#include <cuda_fp16.h>
#include <mscclpp/concurrency_device.hpp>
#include <mscclpp/nvls_device.hpp>
#include <mscclpp/proxy_channel_device.hpp>
#include <mscclpp/sm_channel_device.hpp>
@@ -775,3 +776,57 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
globalFlag += 1;
}
}
// -------------------------------------------
// AllReduce6
// NVLS
// -------------------------------------------
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
// All-reduce through an NVLS multicast pointer.
//
// semaphores : device semaphore handles; entries [0, nranks - 1) are used,
//              i.e. one per remote peer.
// nvlsPtrs   : multicast handle; only mcPtr is used by this kernel.
// buff       : unused here; kept so the launch signature stays unchanged.
// my_rank    : rank of the calling process; selects this rank's slice.
// nranks     : number of participating ranks.
// nelem      : total number of elements in the buffer.
//
// NOTE(review): the multicast pointer is always treated as float* and each
// thread moves one uint4 (four floats) per step, independent of TYPE --
// confirm this is intended for non-float TYPE.
extern "C" __global__ void __launch_bounds__(1024, 1)
    allreduce6(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores,
               mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs, TYPE* buff, int my_rank, int nranks,
               size_t nelem) {
  float* mc_ptr = (float*)nvlsPtrs.mcPtr;
  int tid = threadIdx.x;
  int bid = blockIdx.x;

  // Cross-rank barrier before touching the multicast buffer: one thread issues
  // a system-wide fence, then block 0 signals/waits one semaphore per peer.
  if (tid == 0 && bid == 0) {
    __threadfence_system();
  }
  if (bid == 0) {
    if (tid < nranks - 1) {
      semaphores[tid].signal();
      semaphores[tid].wait();
    }
  }
  // Hold every block until block 0 has finished the cross-rank handshake.
  deviceSyncer.sync(gridDim.x);

  // This rank's slice is [my_st, my_en). The bounds and the loop index are
  // kept in 64 bits: nelem is size_t and the previous int variables truncated
  // the 64-bit quotient for very large buffers.
  const int64_t my_st = ((int64_t)nelem * (int64_t)my_rank) / (int64_t)nranks;
  const int64_t my_en = ((int64_t)nelem * (int64_t)(my_rank + 1)) / (int64_t)nranks;

  // Each thread handles four consecutive floats (one uint4) per iteration.
  const int64_t my_offset = ((int64_t)tid + (int64_t)bid * (int64_t)blockDim.x) * 4;
  const int64_t my_step = (int64_t)blockDim.x * (int64_t)gridDim.x * 4;

  for (int64_t idx = my_st + my_offset; idx < my_en; idx += my_step) {
    uint4 val;
    // Presumably a multimem reduce-load followed by a multimem broadcast-store
    // (see nvls_device.hpp) -- TODO confirm.
    nvlsPtrs.multimemLoad(val, mc_ptr + idx);
    nvlsPtrs.multimemStore(val, mc_ptr + idx);
  }

  // Same barrier sequence again so no rank returns before all slices are done.
  deviceSyncer.sync(gridDim.x);
  if (tid == 0 && bid == 0) {
    __threadfence_system();
  }
  if (bid == 0) {
    if (tid < nranks - 1) {
      semaphores[tid].signal();
      semaphores[tid].wait();
    }
  }
  deviceSyncer.sync(gridDim.x);
}
#endif

View File

@@ -2,12 +2,19 @@
# Licensed under the MIT license.
import cupy as cp
from mscclpp_op import MscclppAllReduce1, MscclppAllReduce2, MscclppAllReduce3, MscclppAllReduce4, MscclppAllReduce5
from mscclpp_op import (
MscclppAllReduce1,
MscclppAllReduce2,
MscclppAllReduce3,
MscclppAllReduce4,
MscclppAllReduce5,
MscclppAllReduce6,
)
from nccl_op import NcclAllReduce
from mpi4py import MPI
import cupy.cuda.nccl as nccl
import mscclpp.comm as mscclpp_comm
from mscclpp import ProxyService
from mscclpp import ProxyService, is_nvls_supported
from prettytable import PrettyTable
import netifaces as ni
@@ -121,6 +128,21 @@ def bench_time(niter: int, func):
return cp.cuda.get_elapsed_time(start, end) / niter * 1000.0
def find_best_algo(mscclpp_algos, niter):
    """Tune every candidate algorithm and return the fastest one.

    Args:
        mscclpp_algos: Non-empty list of algorithm objects accepted by
            ``find_best_config`` and exposing ``set_params``.
        niter: Iteration count forwarded to ``find_best_config``.

    Returns:
        The selected element of ``mscclpp_algos``; the same element on every
        rank.
    """
    assert len(mscclpp_algos) > 0
    best_time = float("inf")
    best_index = 0
    for index, algo in enumerate(mscclpp_algos):
        config, cur_time = find_best_config(algo, niter)
        # Leave every candidate on its own best configuration, so whichever
        # one is finally selected is ready to run.
        algo.set_params(*config)
        if cur_time < best_time:
            best_time = cur_time
            best_index = index
    # cur_time looks rank-local (find_best_config visibly broadcasts only the
    # config), so ranks could disagree on the winner and then launch different
    # collectives. Rank 0's choice is made authoritative.
    best_index = MPI.COMM_WORLD.bcast(best_index, root=0)
    best_algo = mscclpp_algos[best_index]
    if MPI.COMM_WORLD.rank == 0:
        print(best_algo, end="", flush=True)
    return best_algo
def find_best_config(mscclpp_call, niter):
best_time = 10000000.0
for config in mscclpp_call.auto_tune():
@@ -133,7 +155,7 @@ def find_best_config(mscclpp_call, niter):
best_config = MPI.COMM_WORLD.bcast(best_config, root=0)
if MPI.COMM_WORLD.rank == 0:
print(best_config, end="", flush=True)
return best_config
return best_config, best_time
def run_benchmark(
@@ -145,26 +167,27 @@ def run_benchmark(
proxy_service = None
if MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 1:
proxy_service = ProxyService()
if memory.nbytes < 2**20:
mscclpp_call = MscclppAllReduce2(mscclpp_group, memory, memory_out)
elif memory.nbytes < 2**29:
mscclpp_call = MscclppAllReduce1(mscclpp_group, memory)
mscclpp_algos = [MscclppAllReduce2(mscclpp_group, memory, memory_out)]
else:
proxy_service = ProxyService()
mscclpp_call = MscclppAllReduce3(mscclpp_group, memory, proxy_service)
proxy_service.start_proxy()
mscclpp_algos = [
MscclppAllReduce1(mscclpp_group, memory),
MscclppAllReduce3(mscclpp_group, memory, proxy_service),
]
if is_nvls_supported():
mscclpp_algos.append(MscclppAllReduce6(mscclpp_group, nelem, data_type))
else:
if memory.nbytes < 2**22:
proxy_service = ProxyService()
mscclpp_call = MscclppAllReduce5(mscclpp_group, memory, memory_out, N_GPUS_PER_NODE, proxy_service)
proxy_service.start_proxy()
mscclpp_algos = [MscclppAllReduce5(mscclpp_group, memory, memory_out, N_GPUS_PER_NODE, proxy_service)]
else:
proxy_service = ProxyService()
mscclpp_call = MscclppAllReduce4(mscclpp_group, memory, N_GPUS_PER_NODE, proxy_service)
proxy_service.start_proxy()
mscclpp_algos = [MscclppAllReduce4(mscclpp_group, memory, N_GPUS_PER_NODE, proxy_service)]
best_config = find_best_config(mscclpp_call, 20)
mscclpp_call.set_params(*best_config)
proxy_service.start_proxy()
MPI.COMM_WORLD.barrier()
mscclpp_call = find_best_algo(mscclpp_algos, 20)
if isinstance(mscclpp_call, MscclppAllReduce6):
memory = mscclpp_call.get_memory()
nccl_call = NcclAllReduce(nccl_op, memory)
@@ -177,13 +200,8 @@ def run_benchmark(
nccl_algBw = memory_nbytes / nccl_time / 1e3
nccl_check = "PASS" if check_correctness(memory, nccl_call) else "FAIL"
if (
isinstance(mscclpp_call, MscclppAllReduce3)
or isinstance(mscclpp_call, MscclppAllReduce5)
or isinstance(mscclpp_call, MscclppAllReduce4)
):
MPI.COMM_WORLD.barrier()
proxy_service.stop_proxy()
MPI.COMM_WORLD.barrier()
proxy_service.stop_proxy()
speed_up = nccl_time / mscclpp_time
if MPI.COMM_WORLD.rank == 0:
@@ -247,7 +265,8 @@ if __name__ == "__main__":
mscclpp_algbw = []
nccl_algbw = []
speed_ups = []
for i in range(10, 29):
end_range = 28 if is_nvls_supported() else 29
for i in range(10, end_range):
if MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 1:
nelems = 2**i
elif MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 2:

View File

@@ -1,7 +1,7 @@
import os
import cupy as cp
import ctypes
from mscclpp import Transport, ProxyService
from mscclpp import Transport, ProxyService, SmDevice2DeviceSemaphore
import mscclpp.comm as mscclpp_comm
from mscclpp.utils import KernelBuilder, pack
@@ -418,3 +418,82 @@ class MscclppAllReduce5:
for block_size in block_size_to_try:
self.set_params(nblocks, block_size)
yield nblocks, block_size
class MscclppAllReduce6:
    """All-reduce operator that launches the ``allreduce6`` kernel on an NVLS buffer.

    The operator allocates its own buffer through the NVLS connection; callers
    obtain it with ``get_memory()`` instead of passing one in.
    """

    def __init__(
        self,
        group: mscclpp_comm.CommGroup,
        nelem: int,
        memory_dtype: cp.dtype,
        block_size: int = 1024,
        nblocks: int = 32,
    ):
        """Create connections, allocate the NVLS buffer and compile the kernel.

        Args:
            group: Communication group used for connections and semaphores.
            nelem: Number of elements in the all-reduce buffer.
            memory_dtype: Element type of the buffer.
            block_size: Initial threads per block for the launch.
            nblocks: Initial number of blocks for the launch.
        """
        self.group = group
        buffer_size = nelem * memory_dtype().itemsize
        type_str = type_to_str(memory_dtype)

        all_ranks = list(range(group.nranks))
        remote_nghrs = list(all_ranks)
        remote_nghrs.remove(self.group.my_rank)

        self.group.barrier()
        # CudaIpc connections to every remote rank, plus one NVLS connection
        # that covers all ranks (including this one).
        self.nvlink_connections = self.group.make_connection(remote_nghrs, Transport.CudaIpc)
        self.nvls_connection = group.make_connection(all_ranks, Transport.Nvls)

        # Round the requested size up to a whole number of multicast granules.
        min_gran = self.nvls_connection.get_multicast_min_granularity()
        granule_count = (buffer_size + min_gran - 1) // min_gran
        aligned_buffer_size = int(granule_count * min_gran)
        self.nvls_mem_handle = self.nvls_connection.allocate_bind_memory(aligned_buffer_size)

        # Expose the NVLS allocation as a CuPy array without taking ownership.
        self.memory_ptr = self.nvls_mem_handle.get_device_ptr()
        unowned = cp.cuda.UnownedMemory(self.memory_ptr, aligned_buffer_size, None)
        self.cp_memory_ptr = cp.cuda.MemoryPointer(unowned, 0)
        self.memory = cp.ndarray(nelem, memory_dtype, self.cp_memory_ptr)

        # One device-to-device semaphore per remote neighbor.
        self.semaphores = group.make_semaphore(self.nvlink_connections, SmDevice2DeviceSemaphore)

        file_dir = os.path.dirname(os.path.abspath(__file__))
        builder = KernelBuilder(
            file="allreduce.cu",
            kernel_name="allreduce6",
            file_dir=file_dir,
            macro_dict={"TYPE": type_str},
        )
        self.kernel = builder.get_compiled_kernel()

        # Raw semaphore handles for every rank except our own, in rank order.
        self.device_handles = [
            self.semaphores[rank].device_handle().raw
            for rank in range(self.group.nranks)
            if rank != self.group.my_rank
        ]
        self.device_handles_cp = cp.asarray(memoryview(b"".join(self.device_handles)), dtype=cp.uint8)
        self.nvls_handle = self.nvls_mem_handle.device_handle().raw

        self.set_params(nblocks, block_size)

    def get_memory(self):
        """Return the CuPy array that views the NVLS-allocated buffer."""
        return self.memory

    def __call__(self, stream_ptr):
        """Launch the kernel on ``stream_ptr`` and return the buffer array."""
        self.kernel.launch_kernel(self.params, self.nblocks, self.block_size, 0, stream_ptr)
        return self.memory

    def set_params(self, nblocks, block_size):
        """Store the launch geometry and rebuild the packed kernel arguments."""
        self.nblocks = nblocks
        self.block_size = block_size
        packed_args = pack(
            self.device_handles_cp,
            self.nvls_handle,
            self.memory,
            self.group.my_rank,
            self.group.nranks,
            ctypes.c_size_t(self.memory.size),
        )
        self.params = b""
        self.params += packed_args

    def auto_tune(self):
        """Yield each ``(nblocks, block_size)`` candidate after applying it."""
        block_size_to_try = (256, 512, 1024)
        for nblocks in (8, 12, 16, 24, 32, 48, 64, 72, 96, 108):
            for block_size in block_size_to_try:
                self.set_params(nblocks, block_size)
                yield nblocks, block_size

66
python/test/nvls_test.cu Normal file
View File

@@ -0,0 +1,66 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <mscclpp/concurrency_device.hpp>
#include <mscclpp/nvls_device.hpp>
#include <mscclpp/poll_device.hpp>
#include <mscclpp/semaphore_device.hpp>
__device__ mscclpp::DeviceSyncer deviceSyncer;
// NVLS smoke test.
//
// Every rank fills its local buffer with its own rank id, all ranks
// synchronise, each rank runs a multimem load + store over its own slice of
// the multicast buffer, all ranks synchronise again, and finally every element
// of the local buffer is expected to equal 0 + 1 + ... + (nranks - 1).
//
// nvlsPtrs   : multicast handle (devicePtr = local view, mcPtr = multicast view).
// semaphores : semaphore handles indexed by rank; the entry for my_rank is
//              never touched.
// nbytes     : buffer size in bytes; interpreted as nbytes / sizeof(float) floats.
extern "C" __global__ void __launch_bounds__(1024, 1)
    nvls_test(mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs,
              mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores, int my_rank, int nranks, int nbytes) {
  int nelem = nbytes / sizeof(float);
  float* dev_ptr = (float*)nvlsPtrs.devicePtr;
  float* mc_ptr = (float*)nvlsPtrs.mcPtr;
  int tid = threadIdx.x;
  int bid = blockIdx.x;

  // Seed the local buffer with this rank's id.
  for (int idx = bid * blockDim.x + tid; idx < nelem; idx += blockDim.x * gridDim.x) {
    dev_ptr[idx] = my_rank;
  }

  // Cross-rank barrier: wait for all local writes, fence, then block 0
  // signals/waits on the semaphore of every other rank.
  deviceSyncer.sync(gridDim.x);
  if (tid == 0 && bid == 0) {
    __threadfence_system();
  }
  if (bid == 0) {
    if (tid < nranks && tid != my_rank) {
      semaphores[tid].signal();
      semaphores[tid].wait();
    }
  }
  deviceSyncer.sync(gridDim.x);

  // This rank's slice is [my_st, my_en); each thread moves four floats (one
  // uint4) per step.
  int my_st = ((int64_t)nelem * (int64_t)my_rank) / (int64_t)nranks;
  int my_en = ((int64_t)nelem * (int64_t)(my_rank + 1)) / (int64_t)nranks;

  int my_offset = (tid + bid * blockDim.x) * 4;
  int my_step = blockDim.x * gridDim.x * 4;

  for (int idx = my_st + my_offset; idx < my_en; idx += my_step) {
    uint4 val;
    // Presumably a multimem reduce-load followed by a multimem broadcast-store
    // (see nvls_device.hpp) -- TODO confirm.
    nvlsPtrs.multimemLoad(val, mc_ptr + idx);
    nvlsPtrs.multimemStore(val, mc_ptr + idx);
  }

  // Second cross-rank barrier before the result is checked.
  deviceSyncer.sync(gridDim.x);
  if (tid == 0 && bid == 0) {
    __threadfence_system();
  }
  if (bid == 0) {
    if (tid < nranks && tid != my_rank) {
      semaphores[tid].signal();
      semaphores[tid].wait();
    }
  }
  deviceSyncer.sync(gridDim.x);

  // Each element must now hold the sum of all rank ids.
  for (int idx = bid * blockDim.x + tid; idx < nelem; idx += blockDim.x * gridDim.x) {
    if (dev_ptr[idx] != ((nranks * (nranks - 1)) / 2)) {
      // The message states the condition that was expected to hold; the old
      // text ("dev_ptr[idx] != nranks") matched neither the check nor its sense.
      __assert_fail("dev_ptr[idx] == (nranks * (nranks - 1)) / 2", __FILE__, __LINE__, __PRETTY_FUNCTION__);
    }
  }
}

View File

@@ -12,6 +12,7 @@ import netifaces as ni
import pytest
from mscclpp import (
EndpointConfig,
Fifo,
Host2DeviceSemaphore,
Host2HostSemaphore,
@@ -19,6 +20,7 @@ from mscclpp import (
SmDevice2DeviceSemaphore,
TcpBootstrap,
Transport,
is_nvls_supported,
)
import mscclpp.comm as mscclpp_comm
from mscclpp.utils import KernelBuilder, pack
@@ -117,13 +119,15 @@ def test_bootstrap_init_gil_release(mpi_group: MpiGroup):
mpi_group.comm.barrier()
def create_and_connect(mpi_group: MpiGroup, transport: str):
if transport == "NVLink" and all_ranks_on_the_same_node(mpi_group) is False:
pytest.skip("cannot use nvlink for cross node")
group = mscclpp_comm.CommGroup(mpi_group.comm)
def create_connection(group: mscclpp_comm.CommGroup, transport: str):
if transport == "NVLS":
all_ranks = list(range(group.nranks))
tran = Transport.Nvls
connection = group.make_connection(all_ranks, tran)
return connection
remote_nghrs = list(range(mpi_group.comm.size))
remote_nghrs.remove(mpi_group.comm.rank)
remote_nghrs = list(range(group.nranks))
remote_nghrs.remove(group.my_rank)
if transport == "NVLink":
tran = Transport.CudaIpc
elif transport == "IB":
@@ -131,20 +135,28 @@ def create_and_connect(mpi_group: MpiGroup, transport: str):
else:
assert False
connections = group.make_connection(remote_nghrs, tran)
return group, connections
return connections
def create_group_and_connection(mpi_group: MpiGroup, transport: str):
    """Build a ``CommGroup`` for ``mpi_group`` and connect it over ``transport``.

    The test is skipped when an intra-node transport ("NVLink" or "NVLS") is
    requested but the ranks are not all on the same node.

    Returns:
        ``(group, connection)`` where ``connection`` is whatever
        ``create_connection`` returns for the transport.
    """
    needs_single_node = transport in ("NVLink", "NVLS")
    if needs_single_node and all_ranks_on_the_same_node(mpi_group) is False:
        pytest.skip("cannot use nvlink/nvls for cross node")
    comm_group = mscclpp_comm.CommGroup(mpi_group.comm)
    return comm_group, create_connection(comm_group, transport)
@parametrize_mpi_groups(2, 4, 8, 16)
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
def test_group_with_connections(mpi_group: MpiGroup, transport: str):
create_and_connect(mpi_group, transport)
create_group_and_connection(mpi_group, transport)
@parametrize_mpi_groups(2, 4, 8, 16)
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
def test_connection_write(mpi_group: MpiGroup, transport: Transport, nelem: int):
group, connections = create_and_connect(mpi_group, transport)
group, connections = create_group_and_connection(mpi_group, transport)
memory = cp.zeros(nelem, dtype=cp.int32)
nelemPerRank = nelem // group.nranks
sizePerRank = nelemPerRank * memory.itemsize
@@ -185,7 +197,7 @@ def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport,
if device == "cpu" and transport == "NVLink":
pytest.skip("nvlink doesn't work with host allocated memory")
group, connections = create_and_connect(mpi_group, transport)
group, connections = create_group_and_connection(mpi_group, transport)
xp = cp if device == "cuda" else np
if group.my_rank == 0:
memory = xp.random.randn(nelem)
@@ -229,7 +241,7 @@ def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport,
@parametrize_mpi_groups(2, 4, 8, 16)
def test_h2h_semaphores(mpi_group: MpiGroup):
group, connections = create_and_connect(mpi_group, "IB")
group, connections = create_group_and_connection(mpi_group, "IB")
semaphores = group.make_semaphore(connections, Host2HostSemaphore)
for rank in connections:
@@ -242,7 +254,7 @@ def test_h2h_semaphores(mpi_group: MpiGroup):
@parametrize_mpi_groups(2, 4, 8, 16)
def test_h2h_semaphores_gil_release(mpi_group: MpiGroup):
group, connections = create_and_connect(mpi_group, "IB")
group, connections = create_group_and_connection(mpi_group, "IB")
semaphores = group.make_semaphore(connections, Host2HostSemaphore)
@@ -267,6 +279,24 @@ def test_h2h_semaphores_gil_release(mpi_group: MpiGroup):
group.barrier()
@parametrize_mpi_groups(8)
@pytest.mark.skipif(is_nvls_supported() is False, reason="NVLS is not supported")
def test_nvls_connection(mpi_group: MpiGroup):
    """Check that NVLS allocations are bounded by the buffer size and can be reused once freed."""
    if all_ranks_on_the_same_node(mpi_group) is False:
        pytest.skip("cannot use nvls for cross node")
    group = mscclpp_comm.CommGroup(mpi_group.comm)
    every_rank = list(range(group.nranks))
    # The second constructor argument is the NVLS buffer size.
    nvls_config = EndpointConfig(Transport.Nvls, 2**22)
    nvls_connection = group.make_connection(every_rank, nvls_config)

    half = 2**21
    # Two half-sized allocations fill the buffer; both handles stay referenced.
    first_half = nvls_connection.allocate_bind_memory(half)
    second_half = nvls_connection.allocate_bind_memory(half)
    # A third allocation must fail while the first two are still alive.
    with pytest.raises(Exception):
        third_half = nvls_connection.allocate_bind_memory(half)
    # the memory is freed on the destructor of the second handle
    second_half = None
    third_half = nvls_connection.allocate_bind_memory(half)
class MscclppKernel:
def __init__(
self,
@@ -278,6 +308,8 @@ class MscclppKernel:
use_packet=False,
scratch=None,
fifo=None,
nvls_mem_handle=None,
nvls_buffer_size=None,
):
file_dir = os.path.dirname(os.path.abspath(__file__))
if test_name == "h2d_semaphore":
@@ -316,11 +348,17 @@ class MscclppKernel:
).get_compiled_kernel()
self.nblocks = 1
self.nthreads = 1024
elif test_name == "nvls":
self._kernel = KernelBuilder(
file="nvls_test.cu", kernel_name="nvls_test", file_dir=file_dir
).get_compiled_kernel()
self.nblocks = 64
self.nthreads = 1024
else:
assert False
self.params = b""
if test_name in ["h2d_semaphore", "d2d_semaphore", "sm_channel", "simple_proxy_channel"]:
if semaphore_or_channels != None:
first_arg = next(iter(semaphore_or_channels.values()))
size_of_semaphore_or_channels = len(first_arg.device_handle().raw)
device_handles = []
@@ -333,6 +371,8 @@ class MscclppKernel:
device_handles.append(semaphore_or_channels[rank].device_handle().raw)
# keep a reference to the device handles so that they don't get garbage collected
self._d_semaphore_or_channels = cp.asarray(memoryview(b"".join(device_handles)), dtype=cp.uint8)
if test_name in ["h2d_semaphore", "d2d_semaphore", "sm_channel", "simple_proxy_channel"]:
self.params += pack(self._d_semaphore_or_channels, my_rank, nranks)
if test_name == "sm_channel":
self.params += pack(tensor.size, use_packet)
@@ -341,9 +381,13 @@ class MscclppKernel:
elif test_name == "fifo":
self.params = fifo.device_handle().raw
elif test_name == "proxy":
semaphore_device_handles = [semaphore.device_handle().raw for semaphore in semaphore_or_channels]
self._d_semaphore_or_channels = cp.asarray(memoryview(b"".join(semaphore_device_handles)), dtype=cp.uint8)
self.params = pack(my_rank, nranks) + fifo.raw + pack(self._d_semaphore_or_channels)
elif test_name == "nvls":
self.params = (
nvls_mem_handle.device_handle().raw
+ pack(self._d_semaphore_or_channels)
+ pack(my_rank, nranks, nvls_buffer_size)
)
def __call__(self):
return self._kernel.launch_kernel(self.params, self.nblocks, self.nthreads, 0, None)
@@ -356,7 +400,7 @@ def test_h2d_semaphores(mpi_group: MpiGroup, transport: str):
for rank in semaphores:
semaphores[rank].signal()
group, connections = create_and_connect(mpi_group, transport)
group, connections = create_group_and_connection(mpi_group, transport)
semaphores = group.make_semaphore(connections, Host2DeviceSemaphore)
kernel = MscclppKernel("h2d_semaphore", group.my_rank, group.nranks, semaphores)
@@ -372,7 +416,7 @@ def test_h2d_semaphores(mpi_group: MpiGroup, transport: str):
@parametrize_mpi_groups(2, 4, 8, 16)
def test_d2d_semaphores(mpi_group: MpiGroup):
group, connections = create_and_connect(mpi_group, "NVLink")
group, connections = create_group_and_connection(mpi_group, "NVLink")
semaphores = group.make_semaphore(connections, SmDevice2DeviceSemaphore)
group.barrier()
@@ -386,7 +430,7 @@ def test_d2d_semaphores(mpi_group: MpiGroup):
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
@pytest.mark.parametrize("use_packet", [False, True])
def test_sm_channels(mpi_group: MpiGroup, nelem: int, use_packet: bool):
group, connections = create_and_connect(mpi_group, "NVLink")
group, connections = create_group_and_connection(mpi_group, "NVLink")
memory = cp.zeros(nelem, dtype=cp.int32)
if use_packet:
@@ -434,7 +478,7 @@ def test_fifo(
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
group, connections = create_and_connect(mpi_group, transport)
group, connections = create_group_and_connection(mpi_group, transport)
memory = cp.zeros(nelem, dtype=cp.int32)
nelemPerRank = nelem // group.nranks
@@ -468,7 +512,7 @@ def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
fifo_device_handle = proxy.fifo_device_handle()
kernel = MscclppKernel(
"proxy", my_rank=group.my_rank, nranks=group.nranks, semaphore_or_channels=list_sem, fifo=fifo_device_handle
"proxy", my_rank=group.my_rank, nranks=group.nranks, semaphore_or_channels=semaphores, fifo=fifo_device_handle
)
proxy.start()
group.barrier()
@@ -484,7 +528,7 @@ def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
@pytest.mark.parametrize("transport", ["NVLink", "IB"])
@pytest.mark.parametrize("use_packet", [False, True])
def test_simple_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packet: bool):
group, connections = create_and_connect(mpi_group, transport)
group, connections = create_group_and_connection(mpi_group, transport)
memory = cp.zeros(nelem, dtype=cp.int32)
if use_packet:
@@ -522,3 +566,27 @@ def test_simple_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, u
proxy_service.stop_proxy()
group.barrier()
assert cp.array_equal(memory, memory_expected)
@parametrize_mpi_groups(4, 8)
@pytest.mark.skipif(is_nvls_supported() is False, reason="NVLS is not supported")
def test_nvls(mpi_group: MpiGroup):
    """Run the ``nvls`` test kernel on an NVLS buffer and wait for it to finish."""
    group, nvls_connection = create_group_and_connection(mpi_group, "NVLS")

    nbytes = 2**21
    mem_handle = nvls_connection.allocate_bind_memory(nbytes)

    # The kernel synchronises ranks through device-to-device semaphores, which
    # are built on separate NVLink connections.
    nvlinks_connections = create_connection(group, "NVLink")
    semaphores = group.make_semaphore(nvlinks_connections, SmDevice2DeviceSemaphore)

    kernel_kwargs = dict(
        my_rank=group.my_rank,
        nranks=group.nranks,
        nvls_mem_handle=mem_handle,
        nvls_buffer_size=nbytes,
        semaphore_or_channels=semaphores,
    )
    kernel = MscclppKernel("nvls", **kernel_kwargs)

    kernel()
    # Wait for the kernel to finish before the ranks leave the test together.
    cp.cuda.runtime.deviceSynchronize()
    group.barrier()