Auto-tune vector sizes for NVLS allreduce6 (#338)

Also fixes bugs in MscclppAllReduce6

Below is the performance when the algorithm is fixed to
MscclppAllReduce6 on 8 H100 GPUs connected with NVLink using CUDA 12.2.

Float16:

+-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+
| Size (fp16) | Time (us) | AlgBW (GB/s) | Correctness | NCCL Time (us) | NCCL AlgBW (GB/s) | NCCL Correctness | Speed Up |

+-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+
| 2.0 KiB | 11.15 | 0.18 | PASS | 13.82 | 0.15 | PASS | 1.24 |
| 4.0 KiB | 11.15 | 0.37 | PASS | 14.74 | 0.28 | PASS | 1.32 |
| 8.0 KiB | 11.14 | 0.74 | PASS | 15.17 | 0.54 | PASS | 1.36 |
| 16.0 KiB | 11.16 | 1.47 | PASS | 15.77 | 1.04 | PASS | 1.41 |
| 32.0 KiB | 11.15 | 2.94 | PASS | 17.50 | 1.87 | PASS | 1.57 |
| 64.0 KiB | 11.18 | 5.86 | PASS | 17.64 | 3.71 | PASS | 1.58 |
| 128.0 KiB | 11.16 | 11.74 | PASS | 17.83 | 7.35 | PASS | 1.60 |
| 256.0 KiB | 11.21 | 23.38 | PASS | 18.00 | 14.57 | PASS | 1.60 |
| 512.0 KiB | 11.70 | 44.81 | PASS | 18.42 | 28.46 | PASS | 1.57 |
| 1.0 MiB | 13.64 | 76.87 | PASS | 20.23 | 51.83 | PASS | 1.48 |
| 2.0 MiB | 17.29 | 121.27 | PASS | 31.60 | 66.36 | PASS | 1.83 |
| 4.0 MiB | 25.26 | 166.02 | PASS | 38.74 | 108.26 | PASS | 1.53 |
| 8.0 MiB | 40.17 | 208.83 | PASS | 62.86 | 133.45 | PASS | 1.56 |
| 16.0 MiB | 70.92 | 236.56 | PASS | 113.36 | 147.99 | PASS | 1.60 |
| 32.0 MiB | 131.38 | 255.41 | PASS | 203.21 | 165.13 | PASS | 1.55 |
| 64.0 MiB | 253.39 | 264.84 | PASS | 342.12 | 196.15 | PASS | 1.35 |
| 128.0 MiB | 496.74 | 270.20 | PASS | 670.62 | 200.14 | PASS | 1.35 |
| 256.0 MiB | 982.42 | 273.24 | PASS | 1318.36 | 203.61 | PASS | 1.34 |

+-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+

Float32:

+-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+
| Size (fp32) | Time (us) | AlgBW (GB/s) | Correctness | NCCL Time (us) | NCCL AlgBW (GB/s) | NCCL Correctness | Speed Up |

+-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+
| 4.0 KiB | 11.04 | 0.37 | PASS | 14.79 | 0.28 | PASS | 1.34 |
| 8.0 KiB | 11.15 | 0.73 | PASS | 15.25 | 0.54 | PASS | 1.37 |
| 16.0 KiB | 11.12 | 1.47 | PASS | 15.87 | 1.03 | PASS | 1.43 |
| 32.0 KiB | 11.13 | 2.95 | PASS | 17.21 | 1.90 | PASS | 1.55 |
| 64.0 KiB | 11.11 | 5.90 | PASS | 17.37 | 3.77 | PASS | 1.56 |
| 128.0 KiB | 11.08 | 11.83 | PASS | 17.54 | 7.47 | PASS | 1.58 |
| 256.0 KiB | 11.15 | 23.50 | PASS | 17.71 | 14.80 | PASS | 1.59 |
| 512.0 KiB | 11.56 | 45.34 | PASS | 18.21 | 28.79 | PASS | 1.57 |
| 1.0 MiB | 13.64 | 76.90 | PASS | 19.87 | 52.77 | PASS | 1.46 |
| 2.0 MiB | 17.24 | 121.67 | PASS | 31.63 | 66.30 | PASS | 1.84 |
| 4.0 MiB | 25.19 | 166.47 | PASS | 38.63 | 108.57 | PASS | 1.53 |
| 8.0 MiB | 40.38 | 207.72 | PASS | 62.65 | 133.89 | PASS | 1.55 |
| 16.0 MiB | 70.72 | 237.23 | PASS | 114.57 | 146.44 | PASS | 1.62 |
| 32.0 MiB | 131.49 | 255.18 | PASS | 200.79 | 167.11 | PASS | 1.53 |
| 64.0 MiB | 253.98 | 264.23 | PASS | 342.58 | 195.89 | PASS | 1.35 |
| 128.0 MiB | 496.96 | 270.08 | PASS | 670.64 | 200.13 | PASS | 1.35 |
| 256.0 MiB | 982.83 | 273.12 | PASS | 1318.90 | 203.53 | PASS | 1.34 |
| 512.0 MiB | 1954.07 | 274.75 | PASS | 2609.04 | 205.77 | PASS | 1.34 |

+-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+
This commit is contained in:
Roshan Dathathri
2024-08-15 20:11:54 -07:00
committed by GitHub
parent ead4efc315
commit 7ed13ec4b5
6 changed files with 114 additions and 42 deletions

View File

@@ -25,27 +25,32 @@ struct DeviceMulticastPointerDeviceHandle {
size_t bufferSize;
#if defined(MSCCLPP_DEVICE_CUDA)
template <typename TValue = float4, typename T = float>
template <typename TValue, typename T>
MSCCLPP_DEVICE_INLINE static void multimemLoadReduce(TValue& val, T* ptr) {
if constexpr (std::is_same<TValue, float4>::value && std::is_same<T, float>::value) {
if constexpr (std::is_same_v<TValue, uint4> && std::is_same_v<T, float>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.v4.f32 {%0,%1,%2,%3}, [%4];"
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
: "l"(ptr)
: "memory");
} else if constexpr (std::is_same<TValue, uint4>::value && std::is_same<T, __half2>::value) {
} else if constexpr (std::is_same_v<TValue, uint2> && std::is_same_v<T, float>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.v2.f32 {%0,%1}, [%2];"
: "=r"(val.x), "=r"(val.y)
: "l"(ptr)
: "memory");
} else if constexpr (std::is_same_v<TValue, uint1> && std::is_same_v<T, float>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.f32 {%0}, [%1];" : "=r"(val.x) : "l"(ptr) : "memory");
} else if constexpr (std::is_same_v<TValue, uint4> && std::is_same_v<T, __half2>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];"
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
: "l"(ptr)
: "memory");
} else if constexpr (std::is_same<TValue, uint2>::value && std::is_same<T, __half2>::value) {
} else if constexpr (std::is_same_v<TValue, uint2> && std::is_same_v<T, __half2>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.v2.f16x2 {%0,%1}, [%2];"
: "=r"(val.x), "=r"(val.y)
: "l"(ptr)
: "memory");
} else if constexpr (std::is_same<TValue, uint1>::value && std::is_same<T, __half2>::value) {
} else if constexpr (std::is_same_v<TValue, uint1> && std::is_same_v<T, __half2>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.f16x2 {%0}, [%1];" : "=r"(val.x) : "l"(ptr) : "memory");
} else if constexpr (std::is_same<TValue, uint1>::value && std::is_same<T, __half>::value) {
asm("multimem.ld_reduce.relaxed.sys.global.add.f16 {%0}, [%1];" : "=r"(val.x) : "l"(ptr) : "memory");
} else {
static_assert(dependentFalse<T>, "Not supported type");
}
@@ -53,21 +58,24 @@ struct DeviceMulticastPointerDeviceHandle {
template <typename TValue, typename T>
MSCCLPP_DEVICE_INLINE static void multimemStore(const TValue& val, T* ptr) {
if constexpr (std::is_same<TValue, float4>::value && std::is_same<T, float>::value) {
if constexpr (std::is_same_v<TValue, uint4> && std::is_same_v<T, float>) {
asm volatile("multimem.st.relaxed.sys.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y),
"r"(val.z), "r"(val.w)
: "memory");
} else if constexpr (std::is_same<TValue, uint4>::value && std::is_same<T, __half2>::value) {
} else if constexpr (std::is_same_v<TValue, uint2> && std::is_same_v<T, float>) {
asm volatile("multimem.st.relaxed.sys.global.v2.f32 [%0], {%1,%2};" ::"l"(ptr), "r"(val.x), "r"(val.y)
: "memory");
} else if constexpr (std::is_same_v<TValue, uint1> && std::is_same_v<T, float>) {
asm volatile("multimem.st.relaxed.sys.global.f32 [%0], {%1};" ::"l"(ptr), "r"(val.x) : "memory");
} else if constexpr (std::is_same_v<TValue, uint4> && std::is_same_v<T, __half2>) {
asm volatile("multimem.st.relaxed.sys.global.v4.f16x2 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y),
"r"(val.z), "r"(val.w)
: "memory");
} else if constexpr (std::is_same<TValue, uint2>::value && std::is_same<T, __half2>::value) {
} else if constexpr (std::is_same_v<TValue, uint2> && std::is_same_v<T, __half2>) {
asm volatile("multimem.st.relaxed.sys.global.v2.f16x2 [%0], {%1,%2};" ::"l"(ptr), "r"(val.x), "r"(val.y)
: "memory");
} else if constexpr (std::is_same<TValue, uint1>::value && std::is_same<T, __half2>::value) {
} else if constexpr (std::is_same_v<TValue, uint1> && std::is_same_v<T, __half2>) {
asm volatile("multimem.st.relaxed.sys.global.f16x2 [%0], {%1};" ::"l"(ptr), "r"(val.x) : "memory");
} else if constexpr (std::is_same<TValue, uint1>::value && std::is_same<T, __half>::value) {
asm volatile("multimem.st.relaxed.sys.global.f16 [%0], {%1};" ::"l"(ptr), "r"(val.x) : "memory");
} else {
static_assert(dependentFalse<T>, "Not supported type");
}
@@ -75,21 +83,24 @@ struct DeviceMulticastPointerDeviceHandle {
template <typename TValue, typename T>
MSCCLPP_DEVICE_INLINE static void multimemStoreReduce(const TValue& val, T* ptr) {
if constexpr (std::is_same<TValue, float4>::value && std::is_same<T, float>::value) {
if constexpr (std::is_same_v<TValue, float4> && std::is_same_v<T, float>) {
asm volatile("multimem.red.relaxed.sys.global.add.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y),
"r"(val.z), "r"(val.w)
: "memory");
} else if constexpr (std::is_same<TValue, uint4>::value && std::is_same<T, __half2>::value) {
} else if constexpr (std::is_same_v<TValue, uint2> && std::is_same_v<T, float>) {
asm volatile("multimem.red.relaxed.sys.global.add.v2.f32 [%0], {%1,%2};" ::"l"(ptr), "r"(val.x), "r"(val.y)
: "memory");
} else if constexpr (std::is_same_v<TValue, uint1> && std::is_same_v<T, float>) {
asm volatile("multimem.red.relaxed.sys.global.add.f32 [%0], {%1};" ::"l"(ptr), "r"(val.x) : "memory");
} else if constexpr (std::is_same_v<TValue, uint4> && std::is_same_v<T, __half2>) {
asm volatile("multimem.red.relaxed.sys.global.add.v4.f16x2 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x),
"r"(val.y), "r"(val.z), "r"(val.w)
: "memory");
} else if constexpr (std::is_same<TValue, uint2>::value && std::is_same<T, __half2>::value) {
} else if constexpr (std::is_same_v<TValue, uint2> && std::is_same_v<T, __half2>) {
asm volatile("multimem.red.relaxed.sys.global.add.v2.f16x2 [%0], {%1,%2};" ::"l"(ptr), "r"(val.x), "r"(val.y)
: "memory");
} else if constexpr (std::is_same<TValue, uint1>::value && std::is_same<T, __half2>::value) {
} else if constexpr (std::is_same_v<TValue, uint1> && std::is_same_v<T, __half2>) {
asm volatile("multimem.red.relaxed.sys.global.add.f16x2 [%0], {%1};" ::"l"(ptr), "r"(val.x) : "memory");
} else if constexpr (std::is_same<TValue, uint1>::value && std::is_same<T, __half>::value) {
asm volatile("multimem.red.relaxed.sys.global.add.f16 [%0], {%1};" ::"l"(ptr), "r"(val.x) : "memory");
} else {
static_assert(dependentFalse<T>, "Not supported type");
}

View File

@@ -6,6 +6,7 @@ import os as _os
from ._mscclpp import (
Communicator,
Connection,
connect_nvls_collective,
EndpointConfig,
Fifo,
Host2DeviceSemaphore,

View File

@@ -8,6 +8,7 @@ import cupy as cp
from ._mscclpp import (
Communicator,
Connection,
connect_nvls_collective,
EndpointConfig,
Host2DeviceSemaphore,
Host2HostSemaphore,

View File

@@ -788,7 +788,7 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
// Barrier among all devices followed by a memory fence
// Barrier among all devices
// Should be called by all threads on all devices
// Assumes \p num_threads_per_block >= \p num_ranks
__forceinline__ __device__ void barrier(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores, int thread_id,
@@ -806,36 +806,78 @@ __forceinline__ __device__ void barrier(mscclpp::SmDevice2DeviceSemaphoreDeviceH
deviceSyncer.sync(num_blocks);
}
extern "C" __global__ void __launch_bounds__(1024, 1)
allreduce6(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores,
mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs, TYPE* buff, int my_rank, int nranks,
size_t nelem) {
float* dev_ptr = (float*)nvlsPtrs.devicePtr;
float* mc_ptr = (float*)nvlsPtrs.mcPtr;
// Assumes \p kVecSize is 1, 2, 4, or 8 (default 8)
template <typename DataType = float, int kVecSize = 8>
MSCCLPP_DEVICE_INLINE void allreduce6_helper(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores,
mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs, int my_rank,
int num_ranks, size_t num_elements) {
DataType* mc_ptr = (DataType*)nvlsPtrs.mcPtr;
int tid = threadIdx.x;
int bid = blockIdx.x;
int num_threads_per_block = blockDim.x;
int num_blocks = gridDim.x;
// start with a barrier to ensure all devices have written their values
// to their own memory (that is part of the multicast memory)
// before reading them in this kernel
barrier(semaphores, tid, bid, num_blocks, nranks);
barrier(semaphores, tid, bid, num_blocks, num_ranks);
int my_st = ((int64_t)nelem * (int64_t)my_rank) / (int64_t)nranks;
int my_en = ((int64_t)nelem * (int64_t)(my_rank + 1)) / (int64_t)nranks;
// every device loads, reduces, and stores a partition of the multicast memory
int rank_start = ((int64_t)num_elements * (int64_t)my_rank) / (int64_t)num_ranks;
int rank_end = ((int64_t)num_elements * (int64_t)(my_rank + 1)) / (int64_t)num_ranks;
int my_offset = (tid + bid * blockDim.x) * 4;
int my_step = blockDim.x * gridDim.x * 4;
int thread_offset = (bid * num_threads_per_block + tid) * kVecSize;
int thread_step = (num_threads_per_block * num_blocks) * kVecSize; // number of threads * vector size
for (int idx = my_st + my_offset; idx < my_en; idx += my_step) {
uint4 val; // fits 8 cutlass::half_t elements; i.e., 4 half2 elements
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, mc_ptr + idx);
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, mc_ptr + idx);
for (int idx = rank_start + thread_offset; idx < rank_end; idx += thread_step) {
if constexpr (std::is_same_v<DataType, float> && (kVecSize == 4)) {
uint4 val; // fits 4 float elements
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (float*)(mc_ptr + idx));
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (float*)(mc_ptr + idx));
} else if constexpr (std::is_same_v<DataType, float> && (kVecSize == 2)) {
uint2 val; // fits 2 float elements
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (float*)(mc_ptr + idx));
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (float*)(mc_ptr + idx));
} else if constexpr (std::is_same_v<DataType, float> && (kVecSize == 1)) {
uint1 val; // fits 1 float element
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (float*)(mc_ptr + idx));
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (float*)(mc_ptr + idx));
} else if constexpr (std::is_same_v<DataType, __half> && (kVecSize == 8)) {
uint4 val; // fits 8 cutlass::half_t elements; i.e., 4 half2 elements
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (half2*)(mc_ptr + idx));
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (half2*)(mc_ptr + idx));
} else if constexpr (std::is_same_v<DataType, __half> && (kVecSize == 4)) {
uint2 val; // fits 4 cutlass::half_t elements; i.e., 2 half2 elements
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (half2*)(mc_ptr + idx));
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (half2*)(mc_ptr + idx));
} else if constexpr (std::is_same_v<DataType, __half> && (kVecSize == 2)) {
uint1 val; // fits 2 cutlass::half_t elements; i.e., 1 half2 element
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (half2*)(mc_ptr + idx));
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (half2*)(mc_ptr + idx));
} else {
// not supported: cannot use static_assert because of the way TYPE is handled in this file
assert(false); // Unsupported data type and vector size combination
}
}
// end with a barrier to ensure all devices can now read their values
// from their own memory (that is part of the multicast memory)
// after writing them in this kernel
barrier(semaphores, tid, bid, num_blocks, nranks);
barrier(semaphores, tid, bid, num_blocks, num_ranks);
}
extern "C" __global__ void __launch_bounds__(1024, 1)
allreduce6(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores,
mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs, int my_rank, int num_ranks, size_t num_elements,
size_t vector_size) {
if (vector_size == 8) {
allreduce6_helper<TYPE, 8>(semaphores, nvlsPtrs, my_rank, num_ranks, num_elements);
} else if (vector_size == 4) {
allreduce6_helper<TYPE, 4>(semaphores, nvlsPtrs, my_rank, num_ranks, num_elements);
} else if (vector_size == 2) {
allreduce6_helper<TYPE, 2>(semaphores, nvlsPtrs, my_rank, num_ranks, num_elements);
} else {
allreduce6_helper<TYPE, 1>(semaphores, nvlsPtrs, my_rank, num_ranks, num_elements);
}
}
#endif

View File

@@ -175,7 +175,7 @@ def run_benchmark(
MscclppAllReduce1(mscclpp_group, memory),
MscclppAllReduce3(mscclpp_group, memory, proxy_service),
]
if is_nvls_supported():
if is_nvls_supported() and (data_type == cp.float32 or data_type == cp.float16):
mscclpp_algos.append(MscclppAllReduce6(mscclpp_group, nelem, data_type))
else:
if memory.nbytes < 2**22:

View File

@@ -468,7 +468,16 @@ class MscclppAllReduce6:
self.device_handles_cp = cp.asarray(memoryview(b"".join(self.device_handles)), dtype=cp.uint8)
self.nvls_handle = self.nvls_mem_handle.device_handle().raw
self.set_params(nblocks, block_size)
if self.memory.dtype != cp.float16 and self.memory.dtype != cp.float32:
raise RuntimeError("Unsupported data type")
if self.memory.dtype == cp.float16:
vector_size = 8
elif self.memory.dtype == cp.float32:
vector_size = 4
else:
vector_size = 1
self.set_params(nblocks, block_size, vector_size)
def get_memory(self):
return self.memory
@@ -477,23 +486,31 @@ class MscclppAllReduce6:
self.kernel.launch_kernel(self.params, self.nblocks, self.block_size, 0, stream_ptr)
return self.memory
def set_params(self, nblocks, block_size):
def set_params(self, nblocks, block_size, vector_size):
self.nblocks = nblocks
self.block_size = block_size
self.vector_size = vector_size
self.params = b""
self.params += pack(
self.device_handles_cp,
self.nvls_handle,
self.memory,
self.group.my_rank,
self.group.nranks,
ctypes.c_size_t(self.memory.size),
self.vector_size,
)
def auto_tune(self):
nblocks_to_try = [8, 12, 16, 24, 32, 48, 64, 72, 96, 108]
block_size_to_try = [256, 512, 1024]
if self.memory.dtype == cp.float16:
vector_size_to_try = [8, 4, 2]
elif self.memory.dtype == cp.float32:
vector_size_to_try = [4, 2, 1]
else:
vector_size_to_try = [1]
for nblocks in nblocks_to_try:
for block_size in block_size_to_try:
self.set_params(nblocks, block_size)
yield nblocks, block_size
for vector_size in vector_size_to_try:
self.set_params(nblocks, block_size, vector_size)
yield nblocks, block_size, vector_size