mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-06-07 00:05:19 +00:00
NVLS support. (#250)
Co-authored-by: Saeed Maleki <saemal@microsoft.com> Co-authored-by: Binyang Li <binyli@microsoft.com> Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
This commit is contained in:
@@ -6,6 +6,7 @@ import os as _os
|
||||
from ._mscclpp import (
|
||||
Communicator,
|
||||
Connection,
|
||||
EndpointConfig,
|
||||
Fifo,
|
||||
Host2DeviceSemaphore,
|
||||
Host2HostSemaphore,
|
||||
@@ -19,6 +20,7 @@ from ._mscclpp import (
|
||||
Transport,
|
||||
TransportFlags,
|
||||
version,
|
||||
is_nvls_supported,
|
||||
)
|
||||
|
||||
__version__ = version()
|
||||
|
||||
@@ -8,6 +8,7 @@ import cupy as cp
|
||||
from ._mscclpp import (
|
||||
Communicator,
|
||||
Connection,
|
||||
EndpointConfig,
|
||||
Host2DeviceSemaphore,
|
||||
Host2HostSemaphore,
|
||||
ProxyService,
|
||||
@@ -79,15 +80,21 @@ class CommGroup:
|
||||
assert False # only 8 IBs are supported
|
||||
|
||||
def make_connection(
    self,
    all_ranks: list[int],
    endpoints: EndpointConfig | Transport | dict[int, EndpointConfig] | dict[int, Transport],
) -> dict[int, Connection]:
    """Connect this rank to ``all_ranks`` with the given endpoint configuration.

    Args:
        all_ranks: Ranks to connect. For an NVLS endpoint the list is passed
            unchanged to the collective setup call; otherwise one connection
            is requested per listed rank.
        endpoints: Either a single ``Transport`` / ``EndpointConfig`` used for
            every rank, or a dict keyed by rank giving one per rank.

    Returns:
        For a single NVLS endpoint, the object returned by
        ``connct_nvls_collective`` (NOTE(review): this is not the annotated
        dict -- consider widening the annotation). Otherwise a dict mapping
        each rank to the result of ``.get()`` on what ``connect_on_setup``
        returned for it, read after ``communicator.setup()``.
    """
    if isinstance(endpoints, Transport):
        endpoints = EndpointConfig(endpoints)

    # Only a single (non-dict) configuration carries a `.transport` attribute;
    # guarding on the dict case keeps per-rank dicts from raising AttributeError
    # and lets an explicitly built EndpointConfig reach the NVLS path as well.
    if not isinstance(endpoints, dict) and endpoints.transport == Transport.Nvls:
        # "connct" (sic) is the name the native binding exposes.
        return self.communicator.connct_nvls_collective(all_ranks, endpoints)

    pending = {}
    for rank in all_ranks:
        endpoint = endpoints[rank] if isinstance(endpoints, dict) else endpoints
        pending[rank] = self.communicator.connect_on_setup(rank, 0, endpoint)

    # The values collected above are only read after setup() has run.
    self.communicator.setup()
    return {rank: handle.get() for rank, handle in pending.items()}
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include <nanobind/stl/array.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <mscclpp/core.hpp>
|
||||
|
||||
@@ -72,6 +73,7 @@ void register_core(nb::module_& m) {
|
||||
nb::enum_<Transport>(m, "Transport")
|
||||
.value("Unknown", Transport::Unknown)
|
||||
.value("CudaIpc", Transport::CudaIpc)
|
||||
.value("Nvls", Transport::Nvls)
|
||||
.value("IB0", Transport::IB0)
|
||||
.value("IB1", Transport::IB1)
|
||||
.value("IB2", Transport::IB2)
|
||||
@@ -124,6 +126,24 @@ void register_core(nb::module_& m) {
|
||||
.def("transport", &Connection::transport)
|
||||
.def("remote_transport", &Connection::remoteTransport);
|
||||
|
||||
nb::class_<NvlsConnection::DeviceMulticastPointer>(m, "DeviceMulticastPointer")
|
||||
.def("get_device_ptr",
|
||||
[](NvlsConnection::DeviceMulticastPointer* self) { return (uintptr_t)self->getDevicePtr(); })
|
||||
.def("device_handle", &NvlsConnection::DeviceMulticastPointer::deviceHandle);
|
||||
|
||||
nb::class_<NvlsConnection::DeviceMulticastPointer::DeviceHandle>(m, "DeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("devicePtr", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::devicePtr)
|
||||
.def_rw("mcPtr", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::mcPtr)
|
||||
.def_rw("size", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::bufferSize)
|
||||
.def_prop_ro("raw", [](const NvlsConnection::DeviceMulticastPointer::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<NvlsConnection>(m, "NvlsConnection")
|
||||
.def("allocate_bind_memory", &NvlsConnection::allocateAndBindCuda)
|
||||
.def("get_multicast_min_granularity", &NvlsConnection::getMultiCastMinGranularity);
|
||||
|
||||
nb::class_<Endpoint>(m, "Endpoint")
|
||||
.def("transport", &Endpoint::transport)
|
||||
.def("serialize", &Endpoint::serialize)
|
||||
@@ -132,6 +152,7 @@ void register_core(nb::module_& m) {
|
||||
nb::class_<EndpointConfig>(m, "EndpointConfig")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
|
||||
.def(nb::init<Transport, size_t>(), nb::arg("transport"), nb::arg("nvlsBufferSize"))
|
||||
.def_rw("transport", &EndpointConfig::transport)
|
||||
.def_rw("ib_max_cq_size", &EndpointConfig::ibMaxCqSize)
|
||||
.def_rw("ib_max_cq_poll_num", &EndpointConfig::ibMaxCqPollNum)
|
||||
@@ -168,6 +189,7 @@ void register_core(nb::module_& m) {
|
||||
.def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag"))
|
||||
.def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"),
|
||||
nb::arg("localConfig"))
|
||||
.def("connct_nvls_collective", &Communicator::connctNvlsCollective, nb::arg("allRanks"), nb::arg("config"))
|
||||
.def("remote_rank_of", &Communicator::remoteRankOf)
|
||||
.def("tag_of", &Communicator::tagOf)
|
||||
.def("setup", &Communicator::setup);
|
||||
|
||||
@@ -20,4 +20,5 @@ void register_utils(nb::module_& m) {
|
||||
nb::class_<ScopedTimer, Timer>(m, "ScopedTimer").def(nb::init<std::string>(), nb::arg("name"));
|
||||
|
||||
m.def("get_host_name", &getHostName, nb::arg("maxlen"), nb::arg("delim"));
|
||||
m.def("is_nvls_supported", &isNvlsSupported);
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include <mscclpp/concurrency_device.hpp>
|
||||
#include <mscclpp/nvls_device.hpp>
|
||||
#include <mscclpp/proxy_channel_device.hpp>
|
||||
#include <mscclpp/sm_channel_device.hpp>
|
||||
|
||||
@@ -775,3 +776,57 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
globalFlag += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------
|
||||
// AllReduce6
|
||||
// NVLS
|
||||
// -------------------------------------------
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
|
||||
|
||||
// NVLS all-reduce kernel.
//
// Each rank handles the slice [nelem * my_rank / nranks, nelem * (my_rank + 1) / nranks)
// of the multicast buffer, four 32-bit words (one uint4) per thread per step:
// every word group is read through `multimemLoad` and written back through
// `multimemStore` at the same multicast address.
// NOTE(review): presumably the load performs the cross-rank reduction and the
// store broadcasts the result -- confirm against mscclpp/nvls_device.hpp.
//
// Parameters:
//   semaphores - one handle per peer; the first `nranks - 1` entries are used.
//   nvlsPtrs   - device handle whose `mcPtr` is the multicast address.
//   buff       - unused by this kernel; kept for the launch signature.
//   my_rank    - rank of the calling process.
//   nranks     - number of participating ranks.
//   nelem      - number of elements to reduce.
//
// NOTE(review): the buffer is indexed as `float` regardless of TYPE, and the
// uint4 accesses assume each rank's slice starts 16-byte aligned and has a
// length that is a multiple of 4 elements -- confirm with the callers.
extern "C" __global__ void __launch_bounds__(1024, 1)
    allreduce6(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores,
               mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs, TYPE* buff, int my_rank, int nranks,
               size_t nelem) {
  float* mc_ptr = (float*)nvlsPtrs.mcPtr;
  int tid = threadIdx.x;
  int bid = blockIdx.x;

  // Entry handshake: one system-wide fence, then the first nranks - 1 threads
  // of block 0 each signal and wait on one peer semaphore; the grid-wide sync
  // holds every other block until that exchange has finished.
  if (tid == 0 && bid == 0) {
    __threadfence_system();
  }
  if (bid == 0) {
    if (tid < nranks - 1) {
      semaphores[tid].signal();
      semaphores[tid].wait();
    }
  }
  deviceSyncer.sync(gridDim.x);

  // Keep every offset in 64 bits: the products are already computed in
  // int64_t, and storing them in `int` would truncate once nelem > INT_MAX.
  int64_t my_st = ((int64_t)nelem * (int64_t)my_rank) / (int64_t)nranks;
  int64_t my_en = ((int64_t)nelem * (int64_t)(my_rank + 1)) / (int64_t)nranks;

  // Each thread moves one uint4 (4 words) per iteration, hence the factor 4.
  int64_t my_offset = ((int64_t)tid + (int64_t)bid * (int64_t)blockDim.x) * 4;
  int64_t my_step = (int64_t)blockDim.x * (int64_t)gridDim.x * 4;

  for (int64_t idx = my_st + my_offset; idx < my_en; idx += my_step) {
    uint4 val;
    nvlsPtrs.multimemLoad(val, mc_ptr + idx);
    nvlsPtrs.multimemStore(val, mc_ptr + idx);
  }

  // Exit handshake, mirroring the entry one, so no rank leaves the kernel
  // before its peers have finished writing.
  deviceSyncer.sync(gridDim.x);
  if (tid == 0 && bid == 0) {
    __threadfence_system();
  }

  if (bid == 0) {
    if (tid < nranks - 1) {
      semaphores[tid].signal();
      semaphores[tid].wait();
    }
  }
  deviceSyncer.sync(gridDim.x);
}
|
||||
#endif
|
||||
|
||||
@@ -2,12 +2,19 @@
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import cupy as cp
|
||||
from mscclpp_op import MscclppAllReduce1, MscclppAllReduce2, MscclppAllReduce3, MscclppAllReduce4, MscclppAllReduce5
|
||||
from mscclpp_op import (
|
||||
MscclppAllReduce1,
|
||||
MscclppAllReduce2,
|
||||
MscclppAllReduce3,
|
||||
MscclppAllReduce4,
|
||||
MscclppAllReduce5,
|
||||
MscclppAllReduce6,
|
||||
)
|
||||
from nccl_op import NcclAllReduce
|
||||
from mpi4py import MPI
|
||||
import cupy.cuda.nccl as nccl
|
||||
import mscclpp.comm as mscclpp_comm
|
||||
from mscclpp import ProxyService
|
||||
from mscclpp import ProxyService, is_nvls_supported
|
||||
from prettytable import PrettyTable
|
||||
import netifaces as ni
|
||||
|
||||
@@ -121,6 +128,21 @@ def bench_time(niter: int, func):
|
||||
return cp.cuda.get_elapsed_time(start, end) / niter * 1000.0
|
||||
|
||||
|
||||
def find_best_algo(mscclpp_algos, niter):
    """Tune every candidate algorithm and return the fastest one.

    Each candidate is tuned with ``find_best_config``; the candidate with the
    lowest reported time wins and is left configured with its best parameters.

    Args:
        mscclpp_algos: Non-empty sequence of algorithm objects accepted by
            ``find_best_config`` and exposing ``set_params``.
        niter: Iteration count forwarded to ``find_best_config``.

    Returns:
        The winning algorithm object, or ``None`` if no candidate reported a
        time that compares lower than infinity (e.g. NaN timings).
    """
    assert len(mscclpp_algos) > 0
    # Start from infinity rather than a magic upper bound so that any finite
    # measurement, however slow, can still be selected.
    best_time = float("inf")
    best_algo = None
    best_config = None
    for algo in mscclpp_algos:
        config, cur_time = find_best_config(algo, niter)
        # NOTE(review): the comparison uses the time as seen by this rank; if
        # ranks can observe different times they could pick different
        # algorithms -- confirm that find_best_config returns an agreed value.
        if cur_time < best_time:
            best_time, best_algo, best_config = cur_time, algo, config
    # Apply the winning parameters once, after all candidates were measured.
    if best_algo is not None:
        best_algo.set_params(*best_config)
    if MPI.COMM_WORLD.rank == 0:
        print(best_algo, end="", flush=True)
    return best_algo
|
||||
|
||||
|
||||
def find_best_config(mscclpp_call, niter):
|
||||
best_time = 10000000.0
|
||||
for config in mscclpp_call.auto_tune():
|
||||
@@ -133,7 +155,7 @@ def find_best_config(mscclpp_call, niter):
|
||||
best_config = MPI.COMM_WORLD.bcast(best_config, root=0)
|
||||
if MPI.COMM_WORLD.rank == 0:
|
||||
print(best_config, end="", flush=True)
|
||||
return best_config
|
||||
return best_config, best_time
|
||||
|
||||
|
||||
def run_benchmark(
|
||||
@@ -145,26 +167,27 @@ def run_benchmark(
|
||||
|
||||
proxy_service = None
|
||||
if MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 1:
|
||||
proxy_service = ProxyService()
|
||||
if memory.nbytes < 2**20:
|
||||
mscclpp_call = MscclppAllReduce2(mscclpp_group, memory, memory_out)
|
||||
elif memory.nbytes < 2**29:
|
||||
mscclpp_call = MscclppAllReduce1(mscclpp_group, memory)
|
||||
mscclpp_algos = [MscclppAllReduce2(mscclpp_group, memory, memory_out)]
|
||||
else:
|
||||
proxy_service = ProxyService()
|
||||
mscclpp_call = MscclppAllReduce3(mscclpp_group, memory, proxy_service)
|
||||
proxy_service.start_proxy()
|
||||
mscclpp_algos = [
|
||||
MscclppAllReduce1(mscclpp_group, memory),
|
||||
MscclppAllReduce3(mscclpp_group, memory, proxy_service),
|
||||
]
|
||||
if is_nvls_supported():
|
||||
mscclpp_algos.append(MscclppAllReduce6(mscclpp_group, nelem, data_type))
|
||||
else:
|
||||
if memory.nbytes < 2**22:
|
||||
proxy_service = ProxyService()
|
||||
mscclpp_call = MscclppAllReduce5(mscclpp_group, memory, memory_out, N_GPUS_PER_NODE, proxy_service)
|
||||
proxy_service.start_proxy()
|
||||
mscclpp_algos = [MscclppAllReduce5(mscclpp_group, memory, memory_out, N_GPUS_PER_NODE, proxy_service)]
|
||||
else:
|
||||
proxy_service = ProxyService()
|
||||
mscclpp_call = MscclppAllReduce4(mscclpp_group, memory, N_GPUS_PER_NODE, proxy_service)
|
||||
proxy_service.start_proxy()
|
||||
mscclpp_algos = [MscclppAllReduce4(mscclpp_group, memory, N_GPUS_PER_NODE, proxy_service)]
|
||||
|
||||
best_config = find_best_config(mscclpp_call, 20)
|
||||
mscclpp_call.set_params(*best_config)
|
||||
proxy_service.start_proxy()
|
||||
MPI.COMM_WORLD.barrier()
|
||||
mscclpp_call = find_best_algo(mscclpp_algos, 20)
|
||||
if isinstance(mscclpp_call, MscclppAllReduce6):
|
||||
memory = mscclpp_call.get_memory()
|
||||
|
||||
nccl_call = NcclAllReduce(nccl_op, memory)
|
||||
|
||||
@@ -177,13 +200,8 @@ def run_benchmark(
|
||||
nccl_algBw = memory_nbytes / nccl_time / 1e3
|
||||
nccl_check = "PASS" if check_correctness(memory, nccl_call) else "FAIL"
|
||||
|
||||
if (
|
||||
isinstance(mscclpp_call, MscclppAllReduce3)
|
||||
or isinstance(mscclpp_call, MscclppAllReduce5)
|
||||
or isinstance(mscclpp_call, MscclppAllReduce4)
|
||||
):
|
||||
MPI.COMM_WORLD.barrier()
|
||||
proxy_service.stop_proxy()
|
||||
MPI.COMM_WORLD.barrier()
|
||||
proxy_service.stop_proxy()
|
||||
|
||||
speed_up = nccl_time / mscclpp_time
|
||||
if MPI.COMM_WORLD.rank == 0:
|
||||
@@ -247,7 +265,8 @@ if __name__ == "__main__":
|
||||
mscclpp_algbw = []
|
||||
nccl_algbw = []
|
||||
speed_ups = []
|
||||
for i in range(10, 29):
|
||||
end_range = 28 if is_nvls_supported() else 29
|
||||
for i in range(10, end_range):
|
||||
if MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 1:
|
||||
nelems = 2**i
|
||||
elif MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 2:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import cupy as cp
|
||||
import ctypes
|
||||
from mscclpp import Transport, ProxyService
|
||||
from mscclpp import Transport, ProxyService, SmDevice2DeviceSemaphore
|
||||
import mscclpp.comm as mscclpp_comm
|
||||
from mscclpp.utils import KernelBuilder, pack
|
||||
|
||||
@@ -418,3 +418,82 @@ class MscclppAllReduce5:
|
||||
for block_size in block_size_to_try:
|
||||
self.set_params(nblocks, block_size)
|
||||
yield nblocks, block_size
|
||||
|
||||
|
||||
class MscclppAllReduce6:
    """Host-side wrapper around the ``allreduce6`` kernel in ``allreduce.cu``.

    The constructor sets up CUDA IPC connections to every other rank, an NVLS
    connection across all ranks, allocates the working buffer through the NVLS
    connection, and packs the kernel arguments. The buffer is exposed as a
    cupy array via ``get_memory()``; calling the instance launches the kernel.
    """

    def __init__(
        self,
        group: mscclpp_comm.CommGroup,
        nelem: int,
        memory_dtype: cp.dtype,
        block_size: int = 1024,
        nblocks: int = 32,
    ):
        """Create connections, allocate the buffer and build the kernel.

        Args:
            group: Communication group used for all connections.
            nelem: Number of elements in the buffer.
            memory_dtype: Element type of the buffer.
            block_size: Initial number of threads per block.
            nblocks: Initial number of blocks.
        """
        self.group = group
        itemsize = memory_dtype().itemsize
        requested_bytes = nelem * itemsize
        type_str = type_to_str(memory_dtype)
        every_rank = list(range(group.nranks))
        peers = every_rank.copy()
        peers.remove(self.group.my_rank)

        self.group.barrier()
        # One CUDA IPC connection per other rank, plus one NVLS connection
        # spanning all ranks (including this one).
        self.nvlink_connections = self.group.make_connection(peers, Transport.CudaIpc)
        self.nvls_connection = group.make_connection(every_rank, Transport.Nvls)

        # Round the requested size up to a whole number of granules.
        granularity = self.nvls_connection.get_multicast_min_granularity()
        n_granules = (requested_bytes + granularity - 1) // granularity
        aligned_buffer_size = int(n_granules * granularity)
        self.nvls_mem_handle = self.nvls_connection.allocate_bind_memory(aligned_buffer_size)
        self.memory_ptr = self.nvls_mem_handle.get_device_ptr()

        # The cupy view is created with no owner object; the allocation handle
        # itself is kept on `self`.
        unowned = cp.cuda.UnownedMemory(self.memory_ptr, aligned_buffer_size, None)
        self.cp_memory_ptr = cp.cuda.MemoryPointer(unowned, 0)
        self.memory = cp.ndarray(nelem, memory_dtype, self.cp_memory_ptr)

        # One device-to-device semaphore per CUDA IPC connection.
        self.semaphores = group.make_semaphore(self.nvlink_connections, SmDevice2DeviceSemaphore)
        self.kernel = KernelBuilder(
            file="allreduce.cu",
            kernel_name="allreduce6",
            file_dir=os.path.dirname(os.path.abspath(__file__)),
            macro_dict={"TYPE": type_str},
        ).get_compiled_kernel()

        # Raw semaphore handles of all other ranks, in ascending rank order.
        self.device_handles = [
            self.semaphores[rank].device_handle().raw
            for rank in range(self.group.nranks)
            if rank != self.group.my_rank
        ]
        joined_handles = b"".join(self.device_handles)
        self.device_handles_cp = cp.asarray(memoryview(joined_handles), dtype=cp.uint8)
        self.nvls_handle = self.nvls_mem_handle.device_handle().raw

        self.set_params(nblocks, block_size)

    def get_memory(self):
        """Return the cupy array backed by the NVLS allocation."""
        return self.memory

    def __call__(self, stream_ptr):
        """Launch the kernel on ``stream_ptr`` and return the buffer array."""
        self.kernel.launch_kernel(self.params, self.nblocks, self.block_size, 0, stream_ptr)
        return self.memory

    def set_params(self, nblocks, block_size):
        """Store the launch geometry and re-pack the kernel arguments."""
        self.nblocks = nblocks
        self.block_size = block_size
        packed_args = pack(
            self.device_handles_cp,
            self.nvls_handle,
            self.memory,
            self.group.my_rank,
            self.group.nranks,
            ctypes.c_size_t(self.memory.size),
        )
        self.params = b"" + packed_args

    def auto_tune(self):
        """Apply and yield every ``(nblocks, block_size)`` candidate in turn."""
        candidate_nblocks = (8, 12, 16, 24, 32, 48, 64, 72, 96, 108)
        candidate_block_sizes = (256, 512, 1024)
        for nblocks in candidate_nblocks:
            for block_size in candidate_block_sizes:
                self.set_params(nblocks, block_size)
                yield nblocks, block_size
|
||||
|
||||
66
python/test/nvls_test.cu
Normal file
66
python/test/nvls_test.cu
Normal file
@@ -0,0 +1,66 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <mscclpp/concurrency_device.hpp>
|
||||
#include <mscclpp/nvls_device.hpp>
|
||||
#include <mscclpp/poll_device.hpp>
|
||||
#include <mscclpp/semaphore_device.hpp>
|
||||
|
||||
__device__ mscclpp::DeviceSyncer deviceSyncer;
|
||||
|
||||
// NVLS smoke-test kernel.
//
// Every rank fills its local buffer with its own rank id, handshakes with its
// peers, passes its slice of the multicast buffer through
// multimemLoad/multimemStore, handshakes again, and finally checks that every
// local element equals 0 + 1 + ... + (nranks - 1).
//
// Parameters:
//   nvlsPtrs   - device handle providing `devicePtr` (local view) and `mcPtr`
//                (multicast view) of the buffer.
//   semaphores - indexed by peer rank; the slot at `my_rank` is never touched.
//   my_rank    - rank of the calling process.
//   nranks     - number of participating ranks.
//   nbytes     - buffer size in bytes (interpreted as floats).
extern "C" __global__ void __launch_bounds__(1024, 1)
    nvls_test(mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs,
              mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores, int my_rank, int nranks, int nbytes) {
  int nelem = nbytes / sizeof(float);
  float* dev_ptr = (float*)nvlsPtrs.devicePtr;
  float* mc_ptr = (float*)nvlsPtrs.mcPtr;
  int tid = threadIdx.x;
  int bid = blockIdx.x;

  // Seed the local buffer with this rank's id.
  for (int idx = bid * blockDim.x + tid; idx < nelem; idx += blockDim.x * gridDim.x) {
    dev_ptr[idx] = my_rank;
  }
  deviceSyncer.sync(gridDim.x);
  if (tid == 0 && bid == 0) {
    __threadfence_system();
  }

  // Handshake with every other rank before touching the multicast view.
  if (bid == 0) {
    if (tid < nranks && tid != my_rank) {
      semaphores[tid].signal();
      semaphores[tid].wait();
    }
  }
  deviceSyncer.sync(gridDim.x);

  // This rank's slice of the buffer, in float elements.
  int my_st = ((int64_t)nelem * (int64_t)my_rank) / (int64_t)nranks;
  int my_en = ((int64_t)nelem * (int64_t)(my_rank + 1)) / (int64_t)nranks;

  // One uint4 (four 32-bit words) per thread per iteration.
  int my_offset = (tid + bid * blockDim.x) * 4;
  int my_step = blockDim.x * gridDim.x * 4;

  for (int idx = my_st + my_offset; idx < my_en; idx += my_step) {
    uint4 val;
    nvlsPtrs.multimemLoad(val, mc_ptr + idx);
    nvlsPtrs.multimemStore(val, mc_ptr + idx);
  }

  deviceSyncer.sync(gridDim.x);
  if (tid == 0 && bid == 0) {
    __threadfence_system();
  }

  // Handshake again so no rank validates before all peers have written.
  if (bid == 0) {
    if (tid < nranks && tid != my_rank) {
      semaphores[tid].signal();
      semaphores[tid].wait();
    }
  }
  deviceSyncer.sync(gridDim.x);

  // Sum of all rank ids; loop-invariant, so computed once.
  const int expected = (nranks * (nranks - 1)) / 2;
  for (int idx = bid * blockDim.x + tid; idx < nelem; idx += blockDim.x * gridDim.x) {
    if (dev_ptr[idx] != expected) {
      // Report the condition that was expected to hold.
      __assert_fail("dev_ptr[idx] == (nranks * (nranks - 1)) / 2", __FILE__, __LINE__, __PRETTY_FUNCTION__);
    }
  }
}
|
||||
@@ -12,6 +12,7 @@ import netifaces as ni
|
||||
import pytest
|
||||
|
||||
from mscclpp import (
|
||||
EndpointConfig,
|
||||
Fifo,
|
||||
Host2DeviceSemaphore,
|
||||
Host2HostSemaphore,
|
||||
@@ -19,6 +20,7 @@ from mscclpp import (
|
||||
SmDevice2DeviceSemaphore,
|
||||
TcpBootstrap,
|
||||
Transport,
|
||||
is_nvls_supported,
|
||||
)
|
||||
import mscclpp.comm as mscclpp_comm
|
||||
from mscclpp.utils import KernelBuilder, pack
|
||||
@@ -117,13 +119,15 @@ def test_bootstrap_init_gil_release(mpi_group: MpiGroup):
|
||||
mpi_group.comm.barrier()
|
||||
|
||||
|
||||
def create_and_connect(mpi_group: MpiGroup, transport: str):
|
||||
if transport == "NVLink" and all_ranks_on_the_same_node(mpi_group) is False:
|
||||
pytest.skip("cannot use nvlink for cross node")
|
||||
group = mscclpp_comm.CommGroup(mpi_group.comm)
|
||||
def create_connection(group: mscclpp_comm.CommGroup, transport: str):
|
||||
if transport == "NVLS":
|
||||
all_ranks = list(range(group.nranks))
|
||||
tran = Transport.Nvls
|
||||
connection = group.make_connection(all_ranks, tran)
|
||||
return connection
|
||||
|
||||
remote_nghrs = list(range(mpi_group.comm.size))
|
||||
remote_nghrs.remove(mpi_group.comm.rank)
|
||||
remote_nghrs = list(range(group.nranks))
|
||||
remote_nghrs.remove(group.my_rank)
|
||||
if transport == "NVLink":
|
||||
tran = Transport.CudaIpc
|
||||
elif transport == "IB":
|
||||
@@ -131,20 +135,28 @@ def create_and_connect(mpi_group: MpiGroup, transport: str):
|
||||
else:
|
||||
assert False
|
||||
connections = group.make_connection(remote_nghrs, tran)
|
||||
return group, connections
|
||||
return connections
|
||||
|
||||
|
||||
def create_group_and_connection(mpi_group: MpiGroup, transport: str):
    """Build a ``CommGroup`` for ``mpi_group`` and connect it over ``transport``.

    The test is skipped when ``transport`` is "NVLink" or "NVLS" and the ranks
    are not all on the same node.

    Returns:
        A ``(group, connection)`` tuple, where ``connection`` is whatever
        ``create_connection`` returns for the requested transport.
    """
    needs_single_node = transport in ("NVLink", "NVLS")
    if needs_single_node and all_ranks_on_the_same_node(mpi_group) is False:
        pytest.skip("cannot use nvlink/nvls for cross node")
    comm_group = mscclpp_comm.CommGroup(mpi_group.comm)
    return comm_group, create_connection(comm_group, transport)
|
||||
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
def test_group_with_connections(mpi_group: MpiGroup, transport: str):
    """Smoke test: creating a group and its connections must not raise."""
    # Uses the helper introduced by this change; the former
    # create_and_connect() call is superseded by it.
    create_group_and_connection(mpi_group, transport)
|
||||
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
|
||||
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
|
||||
def test_connection_write(mpi_group: MpiGroup, transport: Transport, nelem: int):
|
||||
group, connections = create_and_connect(mpi_group, transport)
|
||||
group, connections = create_group_and_connection(mpi_group, transport)
|
||||
memory = cp.zeros(nelem, dtype=cp.int32)
|
||||
nelemPerRank = nelem // group.nranks
|
||||
sizePerRank = nelemPerRank * memory.itemsize
|
||||
@@ -185,7 +197,7 @@ def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport,
|
||||
|
||||
if device == "cpu" and transport == "NVLink":
|
||||
pytest.skip("nvlink doesn't work with host allocated memory")
|
||||
group, connections = create_and_connect(mpi_group, transport)
|
||||
group, connections = create_group_and_connection(mpi_group, transport)
|
||||
xp = cp if device == "cuda" else np
|
||||
if group.my_rank == 0:
|
||||
memory = xp.random.randn(nelem)
|
||||
@@ -229,7 +241,7 @@ def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport,
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
def test_h2h_semaphores(mpi_group: MpiGroup):
|
||||
group, connections = create_and_connect(mpi_group, "IB")
|
||||
group, connections = create_group_and_connection(mpi_group, "IB")
|
||||
|
||||
semaphores = group.make_semaphore(connections, Host2HostSemaphore)
|
||||
for rank in connections:
|
||||
@@ -242,7 +254,7 @@ def test_h2h_semaphores(mpi_group: MpiGroup):
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
def test_h2h_semaphores_gil_release(mpi_group: MpiGroup):
|
||||
group, connections = create_and_connect(mpi_group, "IB")
|
||||
group, connections = create_group_and_connection(mpi_group, "IB")
|
||||
|
||||
semaphores = group.make_semaphore(connections, Host2HostSemaphore)
|
||||
|
||||
@@ -267,6 +279,24 @@ def test_h2h_semaphores_gil_release(mpi_group: MpiGroup):
|
||||
group.barrier()
|
||||
|
||||
|
||||
@parametrize_mpi_groups(8)
@pytest.mark.skipif(is_nvls_supported() is False, reason="NVLS is not supported")
def test_nvls_connection(mpi_group: MpiGroup):
    """Check that an NVLS connection enforces its configured buffer size.

    The endpoint is created with 2**22 bytes, so two 2**21-byte allocations
    fit, a third one must fail, and it must succeed again once one of the
    earlier handles has been released.
    """
    if all_ranks_on_the_same_node(mpi_group) is False:
        pytest.skip("cannot use nvls for cross node")
    group = mscclpp_comm.CommGroup(mpi_group.comm)
    every_rank = list(range(group.nranks))
    nvls_endpoint = EndpointConfig(Transport.Nvls, 2**22)
    nvls_connection = group.make_connection(every_rank, nvls_endpoint)

    chunk_bytes = 2**21
    # Both handles are held in locals on purpose so their memory stays bound.
    first_handle = nvls_connection.allocate_bind_memory(chunk_bytes)
    second_handle = nvls_connection.allocate_bind_memory(chunk_bytes)
    with pytest.raises(Exception):
        nvls_connection.allocate_bind_memory(chunk_bytes)
    # Dropping the last reference runs the handle's destructor, which frees
    # its memory and makes room for another allocation.
    del second_handle
    third_handle = nvls_connection.allocate_bind_memory(chunk_bytes)
|
||||
|
||||
|
||||
class MscclppKernel:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -278,6 +308,8 @@ class MscclppKernel:
|
||||
use_packet=False,
|
||||
scratch=None,
|
||||
fifo=None,
|
||||
nvls_mem_handle=None,
|
||||
nvls_buffer_size=None,
|
||||
):
|
||||
file_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
if test_name == "h2d_semaphore":
|
||||
@@ -316,11 +348,17 @@ class MscclppKernel:
|
||||
).get_compiled_kernel()
|
||||
self.nblocks = 1
|
||||
self.nthreads = 1024
|
||||
elif test_name == "nvls":
|
||||
self._kernel = KernelBuilder(
|
||||
file="nvls_test.cu", kernel_name="nvls_test", file_dir=file_dir
|
||||
).get_compiled_kernel()
|
||||
self.nblocks = 64
|
||||
self.nthreads = 1024
|
||||
else:
|
||||
assert False
|
||||
|
||||
self.params = b""
|
||||
if test_name in ["h2d_semaphore", "d2d_semaphore", "sm_channel", "simple_proxy_channel"]:
|
||||
if semaphore_or_channels != None:
|
||||
first_arg = next(iter(semaphore_or_channels.values()))
|
||||
size_of_semaphore_or_channels = len(first_arg.device_handle().raw)
|
||||
device_handles = []
|
||||
@@ -333,6 +371,8 @@ class MscclppKernel:
|
||||
device_handles.append(semaphore_or_channels[rank].device_handle().raw)
|
||||
# keep a reference to the device handles so that they don't get garbage collected
|
||||
self._d_semaphore_or_channels = cp.asarray(memoryview(b"".join(device_handles)), dtype=cp.uint8)
|
||||
|
||||
if test_name in ["h2d_semaphore", "d2d_semaphore", "sm_channel", "simple_proxy_channel"]:
|
||||
self.params += pack(self._d_semaphore_or_channels, my_rank, nranks)
|
||||
if test_name == "sm_channel":
|
||||
self.params += pack(tensor.size, use_packet)
|
||||
@@ -341,9 +381,13 @@ class MscclppKernel:
|
||||
elif test_name == "fifo":
|
||||
self.params = fifo.device_handle().raw
|
||||
elif test_name == "proxy":
|
||||
semaphore_device_handles = [semaphore.device_handle().raw for semaphore in semaphore_or_channels]
|
||||
self._d_semaphore_or_channels = cp.asarray(memoryview(b"".join(semaphore_device_handles)), dtype=cp.uint8)
|
||||
self.params = pack(my_rank, nranks) + fifo.raw + pack(self._d_semaphore_or_channels)
|
||||
elif test_name == "nvls":
|
||||
self.params = (
|
||||
nvls_mem_handle.device_handle().raw
|
||||
+ pack(self._d_semaphore_or_channels)
|
||||
+ pack(my_rank, nranks, nvls_buffer_size)
|
||||
)
|
||||
|
||||
def __call__(self):
|
||||
return self._kernel.launch_kernel(self.params, self.nblocks, self.nthreads, 0, None)
|
||||
@@ -356,7 +400,7 @@ def test_h2d_semaphores(mpi_group: MpiGroup, transport: str):
|
||||
for rank in semaphores:
|
||||
semaphores[rank].signal()
|
||||
|
||||
group, connections = create_and_connect(mpi_group, transport)
|
||||
group, connections = create_group_and_connection(mpi_group, transport)
|
||||
|
||||
semaphores = group.make_semaphore(connections, Host2DeviceSemaphore)
|
||||
kernel = MscclppKernel("h2d_semaphore", group.my_rank, group.nranks, semaphores)
|
||||
@@ -372,7 +416,7 @@ def test_h2d_semaphores(mpi_group: MpiGroup, transport: str):
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
def test_d2d_semaphores(mpi_group: MpiGroup):
|
||||
group, connections = create_and_connect(mpi_group, "NVLink")
|
||||
group, connections = create_group_and_connection(mpi_group, "NVLink")
|
||||
|
||||
semaphores = group.make_semaphore(connections, SmDevice2DeviceSemaphore)
|
||||
group.barrier()
|
||||
@@ -386,7 +430,7 @@ def test_d2d_semaphores(mpi_group: MpiGroup):
|
||||
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
|
||||
@pytest.mark.parametrize("use_packet", [False, True])
|
||||
def test_sm_channels(mpi_group: MpiGroup, nelem: int, use_packet: bool):
|
||||
group, connections = create_and_connect(mpi_group, "NVLink")
|
||||
group, connections = create_group_and_connection(mpi_group, "NVLink")
|
||||
|
||||
memory = cp.zeros(nelem, dtype=cp.int32)
|
||||
if use_packet:
|
||||
@@ -434,7 +478,7 @@ def test_fifo(
|
||||
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
|
||||
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
|
||||
def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
|
||||
group, connections = create_and_connect(mpi_group, transport)
|
||||
group, connections = create_group_and_connection(mpi_group, transport)
|
||||
|
||||
memory = cp.zeros(nelem, dtype=cp.int32)
|
||||
nelemPerRank = nelem // group.nranks
|
||||
@@ -468,7 +512,7 @@ def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
|
||||
fifo_device_handle = proxy.fifo_device_handle()
|
||||
|
||||
kernel = MscclppKernel(
|
||||
"proxy", my_rank=group.my_rank, nranks=group.nranks, semaphore_or_channels=list_sem, fifo=fifo_device_handle
|
||||
"proxy", my_rank=group.my_rank, nranks=group.nranks, semaphore_or_channels=semaphores, fifo=fifo_device_handle
|
||||
)
|
||||
proxy.start()
|
||||
group.barrier()
|
||||
@@ -484,7 +528,7 @@ def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
|
||||
@pytest.mark.parametrize("transport", ["NVLink", "IB"])
|
||||
@pytest.mark.parametrize("use_packet", [False, True])
|
||||
def test_simple_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packet: bool):
|
||||
group, connections = create_and_connect(mpi_group, transport)
|
||||
group, connections = create_group_and_connection(mpi_group, transport)
|
||||
|
||||
memory = cp.zeros(nelem, dtype=cp.int32)
|
||||
if use_packet:
|
||||
@@ -522,3 +566,27 @@ def test_simple_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, u
|
||||
proxy_service.stop_proxy()
|
||||
group.barrier()
|
||||
assert cp.array_equal(memory, memory_expected)
|
||||
|
||||
|
||||
@parametrize_mpi_groups(4, 8)
@pytest.mark.skipif(is_nvls_supported() is False, reason="NVLS is not supported")
def test_nvls(mpi_group: MpiGroup):
    """Run the ``nvls`` test kernel on a 2**21-byte NVLS allocation.

    The kernel performs its own validation on the device; the host only
    launches it, waits for completion and synchronizes the ranks.
    """
    group, nvls_connection = create_group_and_connection(mpi_group, "NVLS")
    buffer_bytes = 2**21
    nvls_memory = nvls_connection.allocate_bind_memory(buffer_bytes)

    # Device-to-device semaphores over NVLink connections, passed to the kernel.
    peer_connections = create_connection(group, "NVLink")
    peer_semaphores = group.make_semaphore(peer_connections, SmDevice2DeviceSemaphore)

    run_nvls_kernel = MscclppKernel(
        "nvls",
        my_rank=group.my_rank,
        nranks=group.nranks,
        semaphore_or_channels=peer_semaphores,
        nvls_mem_handle=nvls_memory,
        nvls_buffer_size=buffer_bytes,
    )

    run_nvls_kernel()
    cp.cuda.runtime.deviceSynchronize()
    group.barrier()
|
||||
|
||||
Reference in New Issue
Block a user