mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 17:00:22 +00:00
Also fixes bugs in MscclppAllReduce6 Below is the performance when the algorithm is fixed to MscclppAllReduce6 on 8 H100 GPUs connected with NVLink using CUDA 12.2. Float16: +-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+ | Size (fp16) | Time (us) | AlgBW (GB/s) | Correctness | NCCL Time (us) | NCCL AlgBW (GB/s) | NCCL Correctness | Speed Up | +-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+ | 2.0 KiB | 11.15 | 0.18 | PASS | 13.82 | 0.15 | PASS | 1.24 | | 4.0 KiB | 11.15 | 0.37 | PASS | 14.74 | 0.28 | PASS | 1.32 | | 8.0 KiB | 11.14 | 0.74 | PASS | 15.17 | 0.54 | PASS | 1.36 | | 16.0 KiB | 11.16 | 1.47 | PASS | 15.77 | 1.04 | PASS | 1.41 | | 32.0 KiB | 11.15 | 2.94 | PASS | 17.50 | 1.87 | PASS | 1.57 | | 64.0 KiB | 11.18 | 5.86 | PASS | 17.64 | 3.71 | PASS | 1.58 | | 128.0 KiB | 11.16 | 11.74 | PASS | 17.83 | 7.35 | PASS | 1.60 | | 256.0 KiB | 11.21 | 23.38 | PASS | 18.00 | 14.57 | PASS | 1.60 | | 512.0 KiB | 11.70 | 44.81 | PASS | 18.42 | 28.46 | PASS | 1.57 | | 1.0 MiB | 13.64 | 76.87 | PASS | 20.23 | 51.83 | PASS | 1.48 | | 2.0 MiB | 17.29 | 121.27 | PASS | 31.60 | 66.36 | PASS | 1.83 | | 4.0 MiB | 25.26 | 166.02 | PASS | 38.74 | 108.26 | PASS | 1.53 | | 8.0 MiB | 40.17 | 208.83 | PASS | 62.86 | 133.45 | PASS | 1.56 | | 16.0 MiB | 70.92 | 236.56 | PASS | 113.36 | 147.99 | PASS | 1.60 | | 32.0 MiB | 131.38 | 255.41 | PASS | 203.21 | 165.13 | PASS | 1.55 | | 64.0 MiB | 253.39 | 264.84 | PASS | 342.12 | 196.15 | PASS | 1.35 | | 128.0 MiB | 496.74 | 270.20 | PASS | 670.62 | 200.14 | PASS | 1.35 | | 256.0 MiB | 982.42 | 273.24 | PASS | 1318.36 | 203.61 | PASS | 1.34 | +-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+ Float32: 
+-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+ | Size (fp32) | Time (us) | AlgBW (GB/s) | Correctness | NCCL Time (us) | NCCL AlgBW (GB/s) | NCCL Correctness | Speed Up | +-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+ | 4.0 KiB | 11.04 | 0.37 | PASS | 14.79 | 0.28 | PASS | 1.34 | | 8.0 KiB | 11.15 | 0.73 | PASS | 15.25 | 0.54 | PASS | 1.37 | | 16.0 KiB | 11.12 | 1.47 | PASS | 15.87 | 1.03 | PASS | 1.43 | | 32.0 KiB | 11.13 | 2.95 | PASS | 17.21 | 1.90 | PASS | 1.55 | | 64.0 KiB | 11.11 | 5.90 | PASS | 17.37 | 3.77 | PASS | 1.56 | | 128.0 KiB | 11.08 | 11.83 | PASS | 17.54 | 7.47 | PASS | 1.58 | | 256.0 KiB | 11.15 | 23.50 | PASS | 17.71 | 14.80 | PASS | 1.59 | | 512.0 KiB | 11.56 | 45.34 | PASS | 18.21 | 28.79 | PASS | 1.57 | | 1.0 MiB | 13.64 | 76.90 | PASS | 19.87 | 52.77 | PASS | 1.46 | | 2.0 MiB | 17.24 | 121.67 | PASS | 31.63 | 66.30 | PASS | 1.84 | | 4.0 MiB | 25.19 | 166.47 | PASS | 38.63 | 108.57 | PASS | 1.53 | | 8.0 MiB | 40.38 | 207.72 | PASS | 62.65 | 133.89 | PASS | 1.55 | | 16.0 MiB | 70.72 | 237.23 | PASS | 114.57 | 146.44 | PASS | 1.62 | | 32.0 MiB | 131.49 | 255.18 | PASS | 200.79 | 167.11 | PASS | 1.53 | | 64.0 MiB | 253.98 | 264.23 | PASS | 342.58 | 195.89 | PASS | 1.35 | | 128.0 MiB | 496.96 | 270.08 | PASS | 670.64 | 200.13 | PASS | 1.35 | | 256.0 MiB | 982.83 | 273.12 | PASS | 1318.90 | 203.53 | PASS | 1.34 | | 512.0 MiB | 1954.07 | 274.75 | PASS | 2609.04 | 205.77 | PASS | 1.34 | +-------------+-----------+--------------+-------------+----------------+-------------------+------------------+----------+
884 lines
36 KiB
Plaintext
884 lines
36 KiB
Plaintext
// Copyright (c) Microsoft Corporation.
|
|
// Licensed under the MIT license.
|
|
|
|
#if defined(__HIP_PLATFORM_AMD__)
|
|
#include <hip/hip_fp16.h>
|
|
#else
|
|
#include <cuda_fp16.h>
|
|
#endif
|
|
|
|
#include <mscclpp/concurrency_device.hpp>
|
|
#include <mscclpp/nvls_device.hpp>
|
|
#include <mscclpp/proxy_channel_device.hpp>
|
|
#include <mscclpp/sm_channel_device.hpp>
|
|
|
|
__device__ mscclpp::DeviceSyncer deviceSyncer;
|
|
__device__ mscclpp::DeviceSyncer allGatherDeviceSyncer;
|
|
__device__ mscclpp::DeviceSyncer reduceScatterDeviceSyncer;
|
|
__device__ mscclpp::DeviceSyncer ibDeviceSyncer;
|
|
|
|
#ifndef TYPE
|
|
#define TYPE float
|
|
#endif
|
|
|
|
#define VECTOR_SIZE (sizeof(int4) / sizeof(TYPE))
|
|
|
|
// Reinterprets the bit pattern of `src` as a value of type `To` without any
// numeric conversion. Both types must have identical size (enforced at
// compile time); implemented with union-based type punning.
template <typename To, typename From>
__forceinline__ __device__ To bit_cast(const From& src) {
  static_assert(sizeof(To) == sizeof(From), "Size mismatch for bit_cast");

  union Caster {
    From from;
    To to;
  };
  Caster caster;
  caster.from = src;
  return caster.to;
}
|
|
|
|
// Generic element-wise addition for plain arithmetic types.
template <typename T>
__forceinline__ __device__ T add_elements(T a, T b) {
  const T sum = a + b;
  return sum;
}
|
|
|
|
// __half2 specialization: adds both half-precision lanes via the paired
// half-add intrinsic.
template <>
__forceinline__ __device__ __half2 add_elements(__half2 a, __half2 b) {
  const __half2 sum = __hadd2(a, b);
  return sum;
}
|
|
|
|
// Lane-wise sum of two int4 values, reinterpreting each 32-bit lane as T
// (or as a packed pair, e.g. __half2) before adding.
template <typename T>
__forceinline__ __device__ int4 add_vectors_helper(int4 a, int4 b) {
  int4 out;
  out.x = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.x), bit_cast<T, int>(b.x)));
  out.y = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.y), bit_cast<T, int>(b.y)));
  out.z = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.z), bit_cast<T, int>(b.z)));
  out.w = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.w), bit_cast<T, int>(b.w)));
  return out;
}
|
|
|
|
// Public entry: element-wise int4 addition with lanes interpreted as T.
template <typename T>
__forceinline__ __device__ int4 add_vectors(int4 a, int4 b) {
  return add_vectors_helper<T>(a, b);
}
|
|
|
|
// __half specialization: each 32-bit lane carries two halves, so route
// through the __half2 helper to add both in one step.
template <>
__forceinline__ __device__ int4 add_vectors<__half>(int4 a, int4 b) {
  return add_vectors_helper<__half2>(a, b);
}
|
|
|
|
// Lane-wise sum of two uint2 values, reinterpreting each 32-bit lane as T.
template <typename T>
__forceinline__ __device__ uint2 add_vectors_helper(uint2 a, uint2 b) {
  uint2 out;
  out.x = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.x), bit_cast<T, int>(b.x)));
  out.y = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.y), bit_cast<T, int>(b.y)));
  return out;
}
|
|
|
|
// Public entry: element-wise uint2 addition with lanes interpreted as T.
template <typename T>
__forceinline__ __device__ uint2 add_vectors(uint2 a, uint2 b) {
  return add_vectors_helper<T>(a, b);
}
|
|
|
|
// __half specialization: treat each 32-bit lane as a __half2 pair.
template <>
__forceinline__ __device__ uint2 add_vectors<__half>(uint2 a, uint2 b) {
  return add_vectors_helper<__half2>(a, b);
}
|
|
|
|
// Sum of two 32-bit words reinterpreted as T.
template <typename T>
__forceinline__ __device__ int add_vectors_helper(int a, int b) {
  const T lhs = bit_cast<T, int>(a);
  const T rhs = bit_cast<T, int>(b);
  return bit_cast<int, T>(add_elements(lhs, rhs));
}
|
|
|
|
// Public entry: 32-bit word addition with the word interpreted as T.
template <typename T>
__forceinline__ __device__ int add_vectors(int a, int b) {
  return add_vectors_helper<T>(a, b);
}
|
|
|
|
// __half specialization: one 32-bit word holds two halves; add as __half2.
template <>
__forceinline__ __device__ int add_vectors<__half>(int a, int b) {
  return add_vectors_helper<__half2>(a, b);
}
|
|
|
|
// In-place element-wise sum: dst[i] += src[i], performed cooperatively by
// `nBlocks` thread blocks whose logical ids start at `blockId`. The bulk of
// the range uses 16-byte int4 loads/stores; the (nElem % 4) tail is handled
// with scalar 32-bit accesses.
// NOTE(review): callers in this file pass counts pre-divided by
// sizeof(int)/sizeof(TYPE), i.e. nElem counts 32-bit words — confirm.
// Fix: the original used signed `int` loop counters compared against size_t
// bounds (mixed-sign comparison; the index could overflow for buffers with
// more than 2^31 words). Indices and stride are now size_t.
__forceinline__ __device__ void vectorSum(TYPE* dst, TYPE* src, size_t nElem, int blockId, int nBlocks) {
  size_t nInt4 = nElem / 4;
  size_t nLastInts = nElem % 4;
  int4* dst4 = (int4*)dst;
  int4* src4 = (int4*)src;
  const size_t stride = (size_t)blockDim.x * nBlocks;
  for (size_t i = threadIdx.x + (size_t)blockId * blockDim.x; i < nInt4; i += stride) {
    dst4[i] = add_vectors<TYPE>(dst4[i], src4[i]);
  }
  if (nLastInts > 0) {
    // Scalar tail: the 1-3 words left over after the int4-aligned region.
    int* dstLast = ((int*)dst) + nInt4 * 4;
    int* srcLast = ((int*)src) + nInt4 * 4;
    for (size_t i = threadIdx.x + (size_t)blockId * blockDim.x; i < nLastInts; i += stride) {
      dstLast[i] = add_vectors<TYPE>(dstLast[i], srcLast[i]);
    }
  }
}
|
|
|
|
// Grid-wide convenience overload: every launched block participates, using
// its own blockIdx.x within the full grid.
__forceinline__ __device__ void vectorSum(TYPE* dst, TYPE* src, size_t nElem) {
  vectorSum(dst, src, nElem, blockIdx.x, gridDim.x);
}
|
|
|
|
// -------------------------------------------
|
|
// AllReduce1
|
|
// -------------------------------------------
|
|
|
|
// Single-node flat allreduce over SM (peer-to-peer) channels.
// Each rank owns the rank-th chunk of `buff`; it reads every peer's copy of
// that chunk, accumulates, and either writes the result back into all peers
// (READ_ONLY == 0) or lets each peer pull the final chunks itself
// (READ_ONLY == 1, see the trailing get() loop).
// Preconditions visible here: smChans holds nranks-1 peer channels ordered by
// peer index; the first nPeer*2 threads of the grid perform the signal/wait
// handshakes. READ_ONLY is a compile-time switch so the dead branch is elided.
template <int READ_ONLY>
__device__ void allreduce1_helper(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff, int rank, int nranks,
                                  size_t nelems) {
  const size_t chunkSize = nelems / nranks;
  // Single rank: nothing to reduce.
  if (nranks == 1) return;
  const int nPeer = nranks - 1;
  const size_t indexOffset = rank * chunkSize;
  const size_t indexOffset4 = indexOffset / VECTOR_SIZE;
  int4* buff4 = (int4*)buff;
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;

  // synchronize everyone
  // System-wide fence so this rank's prior writes to buff are visible to
  // peers before they are signaled.
  if (tid == 0) {
    __threadfence_system();
  }
  __syncthreads();
  // Threads [0, nPeer) signal each peer; threads [nPeer, 2*nPeer) wait for
  // the matching signal from each peer.
  if (tid < nPeer) {
    smChans[tid].relaxedSignal();
  }
  if (tid >= nPeer && tid < nPeer * 2) {
    smChans[tid - nPeer].wait();
  }
  deviceSyncer.sync(gridDim.x);

  // use int4 as much as possible
  const size_t nInt4 = chunkSize / VECTOR_SIZE;
  for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nInt4; idx += blockDim.x * gridDim.x) {
    int4 tmp = buff4[indexOffset4 + idx];
    // Rotate the peer order by rank so peers are not all read in the same
    // order by every rank (spreads read traffic).
    for (int index = 0; index < nPeer; ++index) {
      int4 val;
      int peerIdx = (index + rank);
      if (peerIdx >= nPeer) peerIdx -= nPeer;
      val = smChans[peerIdx].read<int4>(indexOffset4 + idx);
      tmp = add_vectors<TYPE>(tmp, val);
    }
    if (READ_ONLY == 0) {
      // Push the reduced value into every peer's buffer.
      for (int index = 0; index < nPeer; ++index) {
        int peerIdx = (index + rank);
        if (peerIdx >= nPeer) peerIdx -= nPeer;
        smChans[peerIdx].write<int4>(indexOffset4 + idx, tmp);
      }
    }
    buff4[indexOffset4 + idx] = tmp;
  }

  // use the given TYPE for the rest
  // Remainder that did not fit into whole int4 vectors, split evenly across
  // ranks by index range.
  size_t processed = nInt4 * VECTOR_SIZE * nranks;
  const size_t nRemElems = nelems - processed;
  const size_t startIdx = processed + (nRemElems * rank) / nranks;
  const size_t endIdx = processed + (nRemElems * (rank + 1)) / nranks;
  for (int idx = threadIdx.x + blockIdx.x * blockDim.x + startIdx; idx < endIdx; idx += blockDim.x * gridDim.x) {
    TYPE tmp = buff[idx];
    for (int index = 0; index < nPeer; ++index) {
      int peerIdx = (index + rank);
      if (peerIdx >= nPeer) peerIdx -= nPeer;
      TYPE val = smChans[peerIdx].read<TYPE>(idx);
      tmp += val;
    }
    if (READ_ONLY == 0) {
      for (int index = 0; index < nPeer; ++index) {
        int peerIdx = (index + rank);
        if (peerIdx >= nPeer) peerIdx -= nPeer;
        smChans[peerIdx].write<TYPE>(idx, tmp);
      }
    }
    buff[idx] = tmp;
  }

  // synchronize everyone again
  deviceSyncer.sync(gridDim.x);
  if (tid == 0) {
    __threadfence_system();
  }
  __syncthreads();
  if (tid < nPeer) {
    smChans[tid].relaxedSignal();
  }
  if (tid >= nPeer && tid < nPeer * 2) {
    smChans[tid - nPeer].wait();
  }

  if (READ_ONLY) {
    // In read-only mode peers were never written to, so each rank now pulls
    // the other ranks' reduced chunks into its own buffer.
    deviceSyncer.sync(gridDim.x);
    for (int i = 0; i < nPeer; ++i) {
      int peerIdx = (i + rank);
      if (peerIdx >= nPeer) peerIdx -= nPeer;
      // Channel index -> actual rank (channels skip this rank itself).
      const int remoteRank = (peerIdx < rank ? peerIdx : peerIdx + 1);
      size_t offset = chunkSize * remoteRank * sizeof(TYPE);
      smChans[peerIdx].get(offset, chunkSize * sizeof(TYPE), tid, blockDim.x * gridDim.x);
    }
  }
}
|
|
|
|
// Kernel entry for AllReduce1: maps the runtime `read_only` switch onto the
// compile-time READ_ONLY template parameter of the helper.
extern "C" __global__ void __launch_bounds__(1024, 1) allreduce1(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff,
                                                                int rank, int nranks, size_t nelems, int read_only) {
  if (read_only != 0) {
    allreduce1_helper<1>(smChans, buff, rank, nranks, nelems);
    return;
  }
  allreduce1_helper<0>(smChans, buff, rank, nranks, nelems);
}
|
|
|
|
// -------------------------------------------
|
|
// AllReduce2
|
|
// -------------------------------------------
|
|
|
|
__device__ uint64_t globalFlag = 1;
|
|
|
|
// Single-node allreduce using LL (low-latency) packets through a scratch
// buffer. Layout of `scratch` (in LLPacket units): two alternating halves for
// incoming per-peer data (double buffering selected by flag parity), followed
// by two alternating halves for the broadcast of reduced results.
// globalFlag's parity selects which half is live this call; it is bumped by
// one thread at the end so the next invocation flips buffers.
extern "C" __global__ void __launch_bounds__(1024, 1)
    allreduce2(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff, TYPE* scratch, void* resultBuff, int rank,
               int worldSize, size_t nelems) {
  // Convert TYPE-element count to 32-bit-word count.
  nelems = nelems / (sizeof(int) / sizeof(TYPE));
  // This version of allreduce only works for single nodes
  const int nPeers = worldSize - 1;
  // One LLPacket carries 8 bytes of payload (two 32-bit words).
  const size_t nPkts = nelems / 2;
  const int nelemsPerRank = nelems / worldSize;
  const int nPktsPerRank = nelemsPerRank / 2;
  // flag for packets. Initially 1
  const uint32_t flag = (uint32_t)globalFlag;
  // thread block & channel info
  // Blocks are partitioned evenly across peers; each group services one peer.
  const int nBlocksPerPeer = gridDim.x / nPeers;
  const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
  const int peerIdx = blockIdx.x / nBlocksPerPeer;
  // Channel index -> peer's actual rank (channels skip this rank itself).
  const int remoteRank = peerIdx < rank ? peerIdx : peerIdx + 1;
  mscclpp::SmChannelDeviceHandle smChan = smChans[peerIdx];
  const int tid = threadIdx.x + localBlockIdx * blockDim.x;
  // double buffering
  size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LLPacket);
  void* scratchBuff = (void*)((char*)scratch + scratchBaseOffset);
  size_t scratchOffset = scratchBaseOffset + rank * nPktsPerRank * sizeof(mscclpp::LLPacket);
  // Result region sits after both input halves; also double-buffered.
  size_t scratchResultOffset =
      (flag & 1) ? 2 * nPkts * sizeof(mscclpp::LLPacket) : 3 * nPkts * sizeof(mscclpp::LLPacket);
  size_t srcOffset = remoteRank * nelemsPerRank * sizeof(int);
  uint2* src = (uint2*)((char*)buff + rank * nelemsPerRank * sizeof(int));
  uint2* dst = (uint2*)((char*)resultBuff + rank * nelemsPerRank * sizeof(int));

  // step 1: write to scratch buffer
  smChan.putPackets(scratchOffset, srcOffset, nelemsPerRank * sizeof(int), tid, blockDim.x * nBlocksPerPeer, flag);
  // step 2: get data from scratch buffer, reduce data and write result to remote scratch buffer
  for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * gridDim.x) {
    uint2 data = make_uint2(0, 0);
    for (int index = 0; index < nPeers; index++) {
      // NOTE: this inner remoteRank shadows the outer one on purpose — it is
      // the rank whose contribution is being accumulated.
      const int remoteRank = index < rank ? index : index + 1;
      mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + remoteRank * nPktsPerRank;
      // read() spins until the packet's flags match `flag`, i.e. data arrived.
      uint2 val = dstPkt[idx].read(flag);
      data = add_vectors<TYPE>(val, data);
    }
    data = add_vectors<TYPE>(data, src[idx]);
    dst[idx] = data;

    // Broadcast the reduced packet to every peer's result region.
    mscclpp::LLPacket packet;
    packet.data1 = data.x;
    packet.flag1 = flag;
    packet.data2 = data.y;
    packet.flag2 = flag;
    size_t offset = scratchResultOffset / sizeof(mscclpp::LLPacket) + (idx + rank * nPktsPerRank);
    for (int index = 0; index < nPeers; index++) {
      smChans[index].write(offset, packet);
    }
  }
  // step 3: get data result from scratch buffer
  mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)((char*)scratch + scratchResultOffset);
  const int dstOffset = remoteRank * nPktsPerRank;
  uint2* result = (uint2*)((char*)resultBuff + remoteRank * nelemsPerRank * sizeof(int));
  for (int idx = threadIdx.x + localBlockIdx * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * nBlocksPerPeer) {
    uint2 data = dstPkt[idx + dstOffset].read(flag);
    result[idx].x = data.x;
    result[idx].y = data.y;
  }
  // Flip buffer parity for the next call (single writer, no race).
  if (threadIdx.x == 0 && blockIdx.x == 0) {
    globalFlag += 1;
  }
}
|
|
|
|
// -------------------------------------------
|
|
// AllReduce3
|
|
// -------------------------------------------
|
|
|
|
// Ring allreduce over proxy (DMA/IB) channels: a reduce-scatter phase on
// fstRoundChans (steps 1..n) followed by an allgather phase on sndRoundChans
// (steps n..2n-2). Each chunk transfer is split in half so the send of one
// half overlaps the local reduction of the other (software pipelining).
// Only one thread (isComm) drives the channels; all blocks join the
// grid-wide deviceSyncer barriers around each vectorSum.
extern "C" __global__ void __launch_bounds__(1024, 1)
    allreduce3(mscclpp::SimpleProxyChannelDeviceHandle* fstRoundChans,
               mscclpp::SimpleProxyChannelDeviceHandle* sndRoundChans, TYPE* buff, TYPE* scratch, int rank,
               int worldSize, size_t nelems) {
  // Convert TYPE-element count to 32-bit-word count.
  nelems = nelems / (sizeof(int) / sizeof(TYPE));

  int isComm = (threadIdx.x == 0) && (blockIdx.x == 0);
  // Ring neighbors: send to rank+1, receive from rank-1 (mod worldSize).
  int remoteSendRank = (rank + 1) % worldSize;
  int remoteRecvRank = (rank + worldSize - 1) % worldSize;
  // Rank -> channel index (channel arrays skip this rank itself).
  int peerSendId = (remoteSendRank < rank) ? remoteSendRank : remoteSendRank - 1;
  int peerRecvId = (remoteRecvRank < rank) ? remoteRecvRank : remoteRecvRank - 1;

  mscclpp::SimpleProxyChannelDeviceHandle& devFstSendChan = fstRoundChans[peerSendId];
  mscclpp::SimpleProxyChannelDeviceHandle& devFstRecvChan = fstRoundChans[peerRecvId];
  mscclpp::SimpleProxyChannelDeviceHandle& devSndSendChan = sndRoundChans[peerSendId];
  mscclpp::SimpleProxyChannelDeviceHandle& devSndRecvChan = sndRoundChans[peerRecvId];

  // Step 1
  // Kick off the pipeline: send the first half of chunk (rank-1).
  size_t chunkIndex = (rank + worldSize - 1) % worldSize;
  size_t chunkNelem = nelems / worldSize;
  size_t offset = chunkIndex * chunkNelem * sizeof(int);
  if (isComm) {
    if (chunkNelem > 1) {
      devFstSendChan.putWithSignal(offset, chunkNelem / 2 * sizeof(int));
    }
  }

  // Step 2 ~ Step n-1
  for (int step = 2; step < worldSize; ++step) {
    if (isComm) {
      // Wait for the previous half to land before sending the second half.
      if (chunkNelem > 1) {
        devFstRecvChan.wait();
        devFstSendChan.flush();
      }
      devFstSendChan.putWithSignal(offset + chunkNelem / 2 * sizeof(int), (chunkNelem - chunkNelem / 2) * sizeof(int));
    }
    deviceSyncer.sync(gridDim.x);

    // Reduce
    // Accumulate the received first half of the next chunk while its second
    // half is still in flight.
    chunkIndex = (rank + worldSize - step) % worldSize;
    offset = chunkIndex * chunkNelem * sizeof(int);
    int* dst = (int*)((char*)buff + offset);
    int* src = (int*)((char*)scratch + offset);
    vectorSum((TYPE*)dst, (TYPE*)src, chunkNelem / 2);

    if (isComm) {
      devFstRecvChan.wait();
      devFstSendChan.flush();
      if (chunkNelem > 1) {
        devFstSendChan.putWithSignal(offset, chunkNelem / 2 * sizeof(int));
      }
    }
    deviceSyncer.sync(gridDim.x);

    // Reduce the second half once it has arrived.
    dst += chunkNelem / 2;
    src += chunkNelem / 2;
    vectorSum((TYPE*)dst, (TYPE*)src, chunkNelem - chunkNelem / 2);
  }

  // Step n
  // Final reduce-scatter step: after this, this rank owns the fully reduced
  // chunk `rank`, which the second-round channels then circulate.
  if (isComm) {
    if (chunkNelem > 1) {
      devFstRecvChan.wait();
      devFstSendChan.flush();
    }
    devFstSendChan.putWithSignal(offset + chunkNelem / 2 * sizeof(int), (chunkNelem - chunkNelem / 2) * sizeof(int));
  }
  deviceSyncer.sync(gridDim.x);

  offset = rank * chunkNelem * sizeof(int);
  int* dst = (int*)((char*)buff + offset);
  int* src = (int*)((char*)scratch + offset);
  vectorSum((TYPE*)dst, (TYPE*)src, chunkNelem / 2);

  if (isComm) {
    devFstRecvChan.wait();
    devFstSendChan.flush();
    if (chunkNelem > 1) {
      // Switch to the second-round (allgather) channels.
      devSndSendChan.putWithSignal(offset, chunkNelem / 2 * sizeof(int));
    }
  }
  deviceSyncer.sync(gridDim.x);

  dst += chunkNelem / 2;
  src += chunkNelem / 2;
  vectorSum((TYPE*)dst, (TYPE*)src, chunkNelem - chunkNelem / 2);

  if (isComm) {
    if (chunkNelem > 1) {
      devSndRecvChan.wait();
      devSndSendChan.flush();
    }
    devSndSendChan.putWithSignalAndFlush(offset + chunkNelem / 2 * sizeof(int),
                                         (chunkNelem - chunkNelem / 2) * sizeof(int));
  }

  // Step n+1 ~ Step 2n-2
  // Allgather phase: forward each received chunk around the ring unchanged.
  for (int i = 1; i < worldSize - 1; ++i) {
    if (isComm) {
      devSndRecvChan.wait();
    }
    deviceSyncer.sync(gridDim.x);

    // Copy
    chunkIndex = (rank + worldSize - i) % worldSize;
    if (isComm) {
      devSndSendChan.putWithSignalAndFlush(chunkIndex * chunkNelem * sizeof(int), chunkNelem * sizeof(int));
    }
  }

  // Final receive
  if (isComm) {
    devSndRecvChan.wait();
  }
}
|
|
|
|
// -------------------------------------------
|
|
// AllReduce4
|
|
// 2-node
|
|
// -------------------------------------------
|
|
// Intra-node reduce-scatter over SM channels, restricted to the first
// `nBlocks` blocks of the grid. Each rank reduces its own slice
// (startChunkIndex + local rank, offset by offsetInChunk) by reading every
// local peer's copy and accumulating in place.
// NOTE(review): nelems here appears to be counted in 32-bit words (the code
// divides by 4 and indexes int4 buffers) — confirm against callers.
__device__ void localReduceScatterSm(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff, int rank, int nRanksPerNode,
                                     int startChunkIndex, size_t offsetInChunk, size_t chunkSize, size_t nelems,
                                     int nBlocks) {
  // Single-rank node: nothing to reduce.
  if (nRanksPerNode == 1) return;
  // Blocks beyond nBlocks are reserved for other concurrent work.
  if (blockIdx.x >= nBlocks) return;
  const int nPeer = nRanksPerNode - 1;

  const size_t localRankIndexInNode = rank % nRanksPerNode;
  const size_t indexOffset = ((localRankIndexInNode + startChunkIndex) * chunkSize + offsetInChunk);
  const size_t indexOffset4 = indexOffset / 4;

  int4* buff4 = (int4*)buff;

  // Handshake: signal all local peers, then wait for all of them, then
  // barrier across the participating blocks.
  for (int peerIdx = threadIdx.x + blockIdx.x * blockDim.x; peerIdx < nPeer; peerIdx += blockDim.x * nBlocks) {
    smChans[peerIdx].relaxedSignal();
  }
  for (int peerIdx = threadIdx.x + blockIdx.x * blockDim.x; peerIdx < nPeer; peerIdx += blockDim.x * nBlocks) {
    smChans[peerIdx].wait();
  }
  reduceScatterDeviceSyncer.sync(nBlocks);

  const size_t nInt4 = nelems / 4;
  for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nInt4; idx += blockDim.x * nBlocks) {
    int4 tmp = buff4[indexOffset4 + idx];
    // Rotate peer order by local rank so ranks do not all hit the same peer
    // at the same time.
    for (int index = 0; index < nPeer; ++index) {
      int4 val;
      int peerIdx = index + localRankIndexInNode;
      if (peerIdx >= nPeer) peerIdx -= nPeer;
      val = smChans[peerIdx].read<int4>(indexOffset4 + idx);
      tmp = add_vectors<TYPE>(tmp, val);
    }
    buff4[indexOffset4 + idx] = tmp;
  }

  // TODO: deal with rest elements
}
|
|
|
|
// This kernel is the most performant when the number of blocks is a multiple of (nRanksPerNode - 1).
|
|
// Intra-node allgather over SM channels: each participating block pulls a
// contiguous byte range of one remote rank's chunk via get(). Work is split
// peer-first (block -> peer via blockIdx % nPeer), then the peer's byte range
// is divided into 128-byte chunks distributed across that peer's blocks.
// This kernel is the most performant when the number of blocks is a multiple of (nRanksPerNode - 1).
__device__ void localAllGatherSm(mscclpp::SmChannelDeviceHandle* smChans, int rank, int nRanksPerNode,
                                 int startRankChunkIndex, uint64_t offsetInRankChunk, uint64_t rankChunkSize,
                                 uint64_t size, size_t nBlocks) {
  // Single-rank node: nothing to gather.
  if (nRanksPerNode == 1) return;
  // Blocks beyond nBlocks are reserved for other concurrent work.
  if (blockIdx.x >= nBlocks) return;
  const size_t nPeer = nRanksPerNode - 1;
  const size_t peerIdx = blockIdx.x % nPeer;
  // Number of blocks assigned to this peer (first nBlocks % nPeer peers get
  // one extra).
  const size_t nBlockForThisPeer = nBlocks / nPeer + (nBlocks % nPeer > peerIdx ? 1 : 0);
  const size_t peerLocalBlockIdx = blockIdx.x / nPeer;
  const size_t rankLocalIndex = rank % nRanksPerNode;
  // Channel index -> peer's local rank (channels skip this rank itself).
  const int remoteRankLocalIndex = (peerIdx < rankLocalIndex ? peerIdx : peerIdx + 1);

  // Split the data into chunks for aligned data access. Ignore the remainder here and let the last block handle it.
  constexpr size_t chunkBytes = 128;  // heuristic value
  const size_t nChunk = size / chunkBytes;
  const size_t nMinChunkPerBlock = nChunk / nBlockForThisPeer;
  const size_t nRemainderChunk = nChunk % nBlockForThisPeer;

  // Distribute chunks to blocks
  size_t nChunkForThisBlock;
  size_t offsetForThisBlock;
  if (peerLocalBlockIdx < nRemainderChunk) {
    nChunkForThisBlock = nMinChunkPerBlock + 1;
    offsetForThisBlock = (nMinChunkPerBlock + 1) * peerLocalBlockIdx;
  } else {
    nChunkForThisBlock = nMinChunkPerBlock;
    offsetForThisBlock =
        (nMinChunkPerBlock + 1) * nRemainderChunk + (peerLocalBlockIdx - nRemainderChunk) * nMinChunkPerBlock;
  }
  offsetForThisBlock *= chunkBytes;

  // Calculate the size of the data for this block
  size_t sizeForThisBlock = nChunkForThisBlock * chunkBytes;
  // The sub-128-byte tail is appended to the last block's range.
  const size_t lastChunkSize = size - nChunk * chunkBytes;
  if (lastChunkSize > 0 && peerLocalBlockIdx == nBlockForThisPeer - 1) {
    sizeForThisBlock += lastChunkSize;
  }
  // One thread per peer performs the signal/wait handshake; everyone then
  // joins the allgather barrier before pulling data.
  if (threadIdx.x == 0 && peerLocalBlockIdx == 0) {
    smChans[peerIdx].relaxedSignal();
    smChans[peerIdx].wait();
  }
  allGatherDeviceSyncer.sync(nBlocks);
  size_t offset = rankChunkSize * (startRankChunkIndex + remoteRankLocalIndex) + offsetInRankChunk;
  smChans[peerIdx].get(offset + offsetForThisBlock, sizeForThisBlock, threadIdx.x, blockDim.x);
}
|
|
|
|
// Intra-node allgather, all-pairs variant: after a full signal/wait
// handshake, every thread cooperates in pulling each peer's `size`-byte
// chunk (placed at size * remote-local-rank) via get().
__device__ void localAllGatherAllPairsSm(mscclpp::SmChannelDeviceHandle* smChans, int rank, int nRanksPerNode,
                                         uint64_t size, size_t nBlocks) {
  // Single-rank node: nothing to gather.
  if (nRanksPerNode == 1) return;
  // Blocks beyond nBlocks are reserved for other concurrent work.
  if (blockIdx.x >= nBlocks) return;

  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  const int nPeer = nRanksPerNode - 1;

  // First nPeer threads signal; the last nPeer threads of the participating
  // grid wait (distinct thread ranges so signal and wait can overlap).
  if (tid < nPeer) {
    smChans[tid].signal();
  }
  int waitStart = nBlocks * blockDim.x - nPeer;
  if (tid >= waitStart && tid < nBlocks * blockDim.x) {
    smChans[tid - waitStart].wait();
  }
  allGatherDeviceSyncer.sync(nBlocks);
  for (int i = 0; i < nPeer; ++i) {
    // Rotate peer order by rank to spread traffic.
    int peerIdx = (i + rank) % nPeer;
    // Channel index -> peer's local rank (channels skip this rank itself).
    const int remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
    size_t offset = size * remoteRankLocalIndex;
    smChans[peerIdx].get(offset, size, tid, blockDim.x * nBlocks);
  }
}
|
|
|
|
// This is an allgather4 equivalent
|
|
// This is an allgather4 equivalent
// Two-node hierarchical, pipelined allgather: local SM-channel allgathers
// overlap with IB (proxy channel) exchanges of this rank's chunk with its
// cross-node counterpart. Falls back to a plain local allgather when the
// "peer" node is actually this node (single-node case).
__device__ void allGatherSm(mscclpp::SmChannelDeviceHandle* smChans,
                            mscclpp::SimpleProxyChannelDeviceHandle* proxyChans, int rank, int worldSize,
                            int nRanksPerNode, size_t nelemsPerGPU, int pipelineDepth) {
  // this allgather is a pipelined and hierarchical one and only works for two nodes
  // it is implemented as follows:
  // Step 1: each node does a local allgather and concurrently,
  // local GPU i exchange (piplineSize-1)/pipelineSize portion of their data with
  // its cross-node neighbor (local GPU i on the other node) via IB
  // Step 2: each node does a local allgather again with the data just received from its
  // cross-node neighbor in step 1, and concurrently, exchange the rest of the data with
  // its cross-node neighbor
  // Step 3: each node does a local allgather for the last time with the rest of the data

  int pipelineSize = pipelineDepth;
  // Cross-node counterpart: same local index on the other node.
  int peerRank = (rank + nRanksPerNode) % worldSize;
  int peerNodeId = peerRank / nRanksPerNode;
  // Rank -> proxy-channel index (channel array skips this rank itself).
  int peer = (peerRank < rank) ? peerRank : peerRank - 1;
  mscclpp::SimpleProxyChannelDeviceHandle proxyChan = proxyChans[peer];
  // Round the block count down to a multiple of nPeer for even distribution.
  const size_t nBlocksForLocalAllGather = gridDim.x / (nRanksPerNode - 1) * (nRanksPerNode - 1);
  const size_t rankChunkSize = nelemsPerGPU * sizeof(int);
  const int startRankIndexInLocalNode = (rank / nRanksPerNode) * nRanksPerNode;
  const int startRankIndexInPeerNode = (peerRank / nRanksPerNode) * nRanksPerNode;

  // Single node: only the local allgather is needed.
  if (peerNodeId == rank / nRanksPerNode) {
    localAllGatherSm(smChans, rank, nRanksPerNode, 0, 0, rankChunkSize, rankChunkSize, gridDim.x);
    return;
  }

  // Split this rank's chunk: step1Bytes go over IB first (rounded down to a
  // 128-byte boundary for aligned transfers), step2Bytes follow.
  constexpr size_t alignment = 128;
  size_t step1Bytes = (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize * sizeof(int);
  step1Bytes = step1Bytes / alignment * alignment;
  const size_t step2Bytes = nelemsPerGPU * sizeof(int) - step1Bytes;

  // Step 1
  // IB send of the first portion overlaps the local-node allgather.
  if (threadIdx.x == 0 && blockIdx.x == 0 && step1Bytes > 0) {
    proxyChan.putWithSignal(rank * nelemsPerGPU * sizeof(int), step1Bytes);
  }
  localAllGatherSm(smChans, rank, nRanksPerNode, startRankIndexInLocalNode, 0, rankChunkSize, rankChunkSize,
                   nBlocksForLocalAllGather);
  if (threadIdx.x == 0 && blockIdx.x == 0 && step1Bytes > 0) {
    proxyChan.wait();
    proxyChan.flush();
  }
  deviceSyncer.sync(gridDim.x);
  // Step 2
  // IB send of the rest overlaps a local allgather of the bytes received in
  // step 1 (peer node's chunks).
  if (threadIdx.x == 0 && blockIdx.x == 0) {
    proxyChan.putWithSignal(rank * nelemsPerGPU * sizeof(int) + step1Bytes, step2Bytes);
  }
  if (step1Bytes > 0)
    localAllGatherSm(smChans, rank, nRanksPerNode, startRankIndexInPeerNode, 0, rankChunkSize, step1Bytes,
                     nBlocksForLocalAllGather);
  if (threadIdx.x == 0 && blockIdx.x == 0) {
    proxyChan.wait();
    proxyChan.flush();
  }
  deviceSyncer.sync(gridDim.x);
  // Step 3
  // Final local allgather over the remaining step2Bytes of the peer node's
  // chunks.
  localAllGatherSm(smChans, rank, nRanksPerNode, startRankIndexInPeerNode, step1Bytes, rankChunkSize, step2Bytes,
                   nBlocksForLocalAllGather);
}
|
|
|
|
// Two-node hierarchical, pipelined reduce-scatter: ~80% of the blocks run
// the local SM-channel reduce-scatter while the remaining blocks (plus one
// comm thread) handle the IB exchange with the cross-node counterpart and
// reduce the received data. Falls back to a plain local reduce-scatter when
// the "peer" node is actually this node (single-node case).
__device__ void reduceScatterSm(mscclpp::SmChannelDeviceHandle* smChans,
                                mscclpp::SimpleProxyChannelDeviceHandle* proxyChans, TYPE* buff, TYPE* scratch,
                                int rank, int nRanksPerNode, int worldSize,
                                size_t nelems,  // must be divisible by 3
                                int pipelineDepth) {
  // this reduce-scatter algorithm works as follows:
  // Step 1: each node does a local reduce-scatter on peer node data chunks with 1/pipeline portion of chunk data. For
  // example, 2 nodes and each node has 2 ranks. rank 0 and rank 1 perform reduce-scatter on chunk 2 and chunk 3, with
  // 1/pipeline portion of the data.
  // Step 2: each node does a local reduce-scatter on peers data chunks with (pipeline-1)/pipeline portion of chunk
  // data. Meanwhile, exchange the reduced data of the previous step with its cross-node neighbor (same local rank
  // number on the other node) via IB. Then performs a reduce operation.
  // Step 3: each node does a local reduce-scatter on local ranks, meanwhile exchange the reduced data of the previous
  // step with its cross-node neighbor (same local rank number on the other node) via IB. Then performs a reduce
  // operation.
  int pipelineSize = pipelineDepth;
  // Fraction of blocks devoted to the local reduce-scatter; the rest do IB
  // reduction work.
  float nBlocksForReduceScatterRatio = 0.8;
  const size_t chunkSize = nelems / worldSize;
  // Cross-node counterpart: same local index on the other node.
  const int peerRank = (rank + nRanksPerNode) % worldSize;
  int peerNodeId = peerRank / nRanksPerNode;
  // Round down to a multiple of nPeer for even block distribution.
  int nBlocksForReduceScatter =
      (int)(nBlocksForReduceScatterRatio * gridDim.x) / (nRanksPerNode - 1) * (nRanksPerNode - 1);
  // The single comm thread lives in the first block NOT used for the local
  // reduce-scatter.
  int isComm = (threadIdx.x == 0) && (blockIdx.x == nBlocksForReduceScatter);
  // Rank -> proxy-channel index (channel array skips this rank itself).
  int peer = (peerRank < rank) ? peerRank : peerRank - 1;
  int nBlocksRemain = gridDim.x - nBlocksForReduceScatter;
  mscclpp::SimpleProxyChannelDeviceHandle proxyChan = proxyChans[peer];
  // Single node: only the local reduce-scatter is needed.
  if (peerNodeId == rank / nRanksPerNode) {
    localReduceScatterSm(smChans, buff, rank, nRanksPerNode, 0, 0, chunkSize, chunkSize, gridDim.x);
    return;
  }

  // step 1: local reduce
  // Reduce the first 1/pipeline of the PEER node's chunks locally so it can
  // be shipped over IB while the rest is reduced.
  int startChunkIndex = peerNodeId * nRanksPerNode;
  localReduceScatterSm(smChans, buff, rank, nRanksPerNode, startChunkIndex, 0, chunkSize, chunkSize / pipelineSize,
                       nBlocksForReduceScatter);
  deviceSyncer.sync(gridDim.x);

  // step 2: local reduce and exchange data with neighbor
  if (isComm) {
    size_t offset = (peerRank * chunkSize) * sizeof(int);
    // opposite side
    proxyChan.putWithSignal(offset, (chunkSize / pipelineSize * sizeof(int)));
  }
  if (pipelineSize > 1)
    localReduceScatterSm(smChans, buff, rank, nRanksPerNode, startChunkIndex, chunkSize / pipelineSize, chunkSize,
                         (pipelineSize - 1) * chunkSize / pipelineSize, nBlocksForReduceScatter);
  if (isComm) {
    proxyChan.wait();
  }
  // The leftover blocks reduce the IB-received portion in parallel with the
  // local reduce-scatter above.
  if (blockIdx.x >= nBlocksForReduceScatter) {
    ibDeviceSyncer.sync(nBlocksRemain);
    // reduce data received from peer to related rank
    size_t offset = rank * chunkSize * sizeof(int);
    int* dst = (int*)((char*)buff + offset);
    int* src = (int*)((char*)scratch + offset);
    vectorSum((TYPE*)dst, (TYPE*)src, chunkSize / pipelineSize, blockIdx.x - nBlocksForReduceScatter, nBlocksRemain);
  }
  if (isComm) {
    proxyChan.flush();
  }
  deviceSyncer.sync(gridDim.x);

  // step 3: local reduce and exchange data with neighbor
  startChunkIndex = (rank / nRanksPerNode) * nRanksPerNode;
  if (isComm && pipelineSize > 1) {
    size_t offset = (peerRank * chunkSize + chunkSize / pipelineSize) * sizeof(int);
    proxyChan.putWithSignal(offset, (pipelineSize - 1) * chunkSize / pipelineSize * sizeof(int));
  }
  localReduceScatterSm(smChans, buff, rank, nRanksPerNode, startChunkIndex, 0, chunkSize, chunkSize,
                       nBlocksForReduceScatter);
  if (isComm && pipelineSize > 1) {
    proxyChan.wait();
  }
  deviceSyncer.sync(gridDim.x);
  // reduce to related rank, can not overlap since localReduceScatter also calculate the sum
  size_t offset = (rank * chunkSize + chunkSize / pipelineSize) * sizeof(int);
  int* dst = (int*)((char*)buff + offset);
  int* src = (int*)((char*)scratch + offset);
  if (pipelineSize > 1) vectorSum((TYPE*)dst, (TYPE*)src, (pipelineSize - 1) * chunkSize / pipelineSize);
  if (isComm) {
    proxyChan.flush();
  }
}
|
|
|
|
// AllReduce4 (2-node): hierarchical, pipelined reduce-scatter followed by
// allgather, with intra-node traffic on SM channels and inter-node traffic
// on the proxy (IB) channels.
// Fix: the original declaration repeated the __global__ qualifier
// ("... __launch_bounds__(1024, 1) __global__"), a duplicate execution-space
// qualifier; it is declared exactly once here.
extern "C" __global__ void __launch_bounds__(1024, 1)
    allreduce4(mscclpp::SmChannelDeviceHandle* smChans,
               mscclpp::SimpleProxyChannelDeviceHandle* reduceScatterProxyChans,
               mscclpp::SimpleProxyChannelDeviceHandle* allGatherProxyChans, TYPE* buff, TYPE* scratch, int rank,
               int nRanksPerNode, int worldSize, size_t nelems, int pipelineDepth) {
  // Convert TYPE-element count to 32-bit-word count.
  nelems = nelems / (sizeof(int) / sizeof(TYPE));
  reduceScatterSm(smChans, reduceScatterProxyChans, buff, scratch, rank, nRanksPerNode, worldSize, nelems,
                  pipelineDepth);
  // All blocks must finish the reduce-scatter before the allgather reads.
  deviceSyncer.sync(gridDim.x);
  allGatherSm(smChans, allGatherProxyChans, rank, worldSize, nRanksPerNode, nelems / worldSize, pipelineDepth);
}
|
|
|
|
// allreduce 5 for 2-nodes
|
|
// Two-node allreduce using LL (low-latency) packets: intra-node exchange over
// SM channels, inter-node exchange over a proxy channel. Scratch and put
// buffers are double-buffered, selected by the parity of the global flag.
extern "C" __global__ void __launch_bounds__(1024, 1)
    allreduce5(mscclpp::SmChannelDeviceHandle* smChans, mscclpp::SimpleProxyChannelDeviceHandle* proxyChans, TYPE* buff,
               TYPE* scratch, TYPE* putBuff, TYPE* resultBuff, int rank, int nRanksPerNode, int worldSize,
               size_t nelems) {
  // Work in units of int; TYPE may be narrower (e.g. half).
  nelems = nelems / (sizeof(int) / sizeof(TYPE));
  // NOTE(review): despite the cross-node proxy put in step 3, an earlier
  // comment here claimed single-node only; the "2-nodes" header above matches
  // the code, so that stale comment has been dropped.
  const int nPeersInNode = nRanksPerNode - 1;
  // Each LLPacket carries 8 bytes of payload (a uint2), hence nelems / 2.
  const int nPkts = nelems / 2;
  const int nelemsPerLocalRank = nelems / nRanksPerNode;
  const int nPktsPerLocalRank = nelemsPerLocalRank / 2;
  const int localRankId = rank % nRanksPerNode;
  // flag for packets. Initially 1; incremented once per launch (see bottom),
  // which also swaps the double buffers below.
  const uint32_t flag = (uint32_t)globalFlag;
  // thread block & channel info: blocks are partitioned evenly across peers
  const int nBlocksPerPeer = gridDim.x / nPeersInNode;
  const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
  const int peerIdx = blockIdx.x / nBlocksPerPeer;
  // Map a peer index to a rank index, skipping over self.
  const int remoteRankIdx = peerIdx < localRankId ? peerIdx : peerIdx + 1;
  mscclpp::SmChannelDeviceHandle smChan = smChans[peerIdx];
  mscclpp::SimpleProxyChannelDeviceHandle proxyChan = proxyChans[localRankId];
  const int tid = threadIdx.x + localBlockIdx * blockDim.x;
  // double buffering: even/odd flag picks which half of scratch/putBuff to use
  size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LLPacket);
  size_t putBaseOffset = (flag & 1) ? 0 : nPktsPerLocalRank * sizeof(mscclpp::LLPacket);
  void* scratchBuff = (void*)((char*)scratch + scratchBaseOffset);
  size_t scratchOffset = scratchBaseOffset + localRankId * nPktsPerLocalRank * sizeof(mscclpp::LLPacket);
  // Final results are staged in the upper half of scratch (also double buffered).
  size_t scratchResultOffset =
      (flag & 1) ? 2 * nPkts * sizeof(mscclpp::LLPacket) : 3 * nPkts * sizeof(mscclpp::LLPacket);
  size_t srcOffset = remoteRankIdx * nelemsPerLocalRank * sizeof(int);
  // This rank's own shard of the input buffer and of the result buffer.
  uint2* src = (uint2*)((char*)buff + localRankId * nelemsPerLocalRank * sizeof(int));
  uint2* dst = (uint2*)((char*)resultBuff + localRankId * nelemsPerLocalRank * sizeof(int));

  // step 1: write to scratch buffer
  // (push the shard each local peer owns into that peer's scratch, as packets)
  if (nRanksPerNode > 1) {
    smChan.putPackets(scratchOffset, srcOffset, nelemsPerLocalRank * sizeof(int), tid, blockDim.x * nBlocksPerPeer,
                      flag);
  }
  // step 2: get data from scratch buffer, do local reduce-scatter in each node.
  mscclpp::LLPacket* putPkt = (mscclpp::LLPacket*)((char*)putBuff + putBaseOffset);
  for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerLocalRank; idx += blockDim.x * gridDim.x) {
    uint2 data = make_uint2(0, 0);
    for (int index = 0; index < nPeersInNode; index++) {
      const int remoteRank = index < localRankId ? index : index + 1;
      mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + remoteRank * nPktsPerLocalRank;
      // read(flag) spins until the packet flag matches, i.e. data has arrived.
      uint2 val = dstPkt[idx].read(flag);
      data = add_vectors<TYPE>(val, data);
    }
    data = add_vectors<TYPE>(data, src[idx]);
    // Stage the node-local sum for the cross-node put, and keep a local copy.
    putPkt[idx].write(data.x, data.y, flag);
    dst[idx] = data;
  }
  deviceSyncer.sync(gridDim.x);
  // step 3. send local reduced data to remote node.
  if (threadIdx.x == 0 && blockIdx.x == 0) {
    proxyChan.put(scratchOffset, putBaseOffset, nPktsPerLocalRank * sizeof(mscclpp::LLPacket));
    // Flush only when the low 6 flag bits are zero (every 64th launch) to
    // amortize the flush cost.
    if ((flag & 63) == 0) {
      proxyChan.flush();
    }
  }
  // step 4. try to read the data from scratch buffer and write to local peers
  mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + localRankId * nPktsPerLocalRank;
  for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerLocalRank; idx += blockDim.x * gridDim.x) {
    uint2 res = dst[idx];
    // Combine the remote node's partial sum with our node-local sum.
    uint2 val = dstPkt[idx].read(flag);
    res = add_vectors<TYPE>(res, val);

    // Broadcast the final value to every local peer's result region.
    mscclpp::LLPacket packet;
    packet.data1 = res.x;
    packet.flag1 = flag;
    packet.data2 = res.y;
    packet.flag2 = flag;
    size_t offset = scratchResultOffset / sizeof(mscclpp::LLPacket) + (idx + localRankId * nPktsPerLocalRank);
    for (int index = 0; index < nPeersInNode; index++) {
      smChans[index].write(offset, packet);
    }
    dst[idx] = res;
  }

  // step 5: get data result from scratch buffer
  // (collect the shards the other local ranks finalized in step 4)
  dstPkt = (mscclpp::LLPacket*)((char*)scratch + scratchResultOffset);
  const int dstOffset = remoteRankIdx * nPktsPerLocalRank;
  uint2* result = (uint2*)((char*)resultBuff + remoteRankIdx * nelemsPerLocalRank * sizeof(int));
  if (nRanksPerNode > 1) {
    for (int idx = threadIdx.x + localBlockIdx * blockDim.x; idx < nPktsPerLocalRank;
         idx += blockDim.x * nBlocksPerPeer) {
      uint2 data = dstPkt[idx + dstOffset].read(flag);
      result[idx] = data;
    }
  }
  // Advance the flag so the next launch uses the other scratch/put halves.
  if (threadIdx.x == 0 && blockIdx.x == 0) {
    globalFlag += 1;
  }
}
|
|
|
|
// -------------------------------------------
|
|
// AllReduce6
|
|
// NVLS
|
|
// -------------------------------------------
|
|
|
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
|
|
|
|
// Barrier among all devices
|
|
// Should be called by all threads on all devices
|
|
// Assumes \p num_threads_per_block >= \p num_ranks
|
|
// Cross-device barrier: a per-peer semaphore handshake done by block 0,
// followed by a grid-wide sync so every local thread observes it.
// Must be called by all threads on all devices; requires the block size to be
// at least num_ranks so block 0 can service every peer semaphore.
__forceinline__ __device__ void barrier(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores, int thread_id,
                                        int block_id, int num_blocks, int num_ranks) {
  // There is one semaphore per remote peer (none for self).
  const int num_peers = num_ranks - 1;
  if (block_id == 0 && thread_id < num_peers) {
    semaphores[thread_id].signal();
    semaphores[thread_id].wait();
  }

  // Local grid-wide barrier: all blocks wait for the handshake to finish.
  deviceSyncer.sync(num_blocks);
}
|
|
|
|
// Assumes \p kVecSize is 1, 2, 4, or 8 (default 8)
|
|
// Single-pass NVLS allreduce over the multicast pointer: each rank
// load-reduces and stores its own partition via multimem instructions.
// Assumes \p kVecSize is 1, 2, 4, or 8 (default 8).
template <typename DataType = float, int kVecSize = 8>
MSCCLPP_DEVICE_INLINE void allreduce6_helper(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores,
                                             mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs, int my_rank,
                                             int num_ranks, size_t num_elements) {
  DataType* mc_ptr = (DataType*)nvlsPtrs.mcPtr;  // multicast view of the buffer
  int tid = threadIdx.x;
  int bid = blockIdx.x;
  int num_threads_per_block = blockDim.x;
  int num_blocks = gridDim.x;

  // start with a barrier to ensure all devices have written their values
  // to their own memory (that is part of the multicast memory)
  // before reading them in this kernel
  barrier(semaphores, tid, bid, num_blocks, num_ranks);

  // every device loads, reduces, and stores a partition of the multicast memory
  // (64-bit intermediate math avoids overflow in num_elements * rank)
  int rank_start = ((int64_t)num_elements * (int64_t)my_rank) / (int64_t)num_ranks;
  int rank_end = ((int64_t)num_elements * (int64_t)(my_rank + 1)) / (int64_t)num_ranks;

  // grid-stride traversal of [rank_start, rank_end), kVecSize elements per step
  int thread_offset = (bid * num_threads_per_block + tid) * kVecSize;
  int thread_step = (num_threads_per_block * num_blocks) * kVecSize;  // number of threads * vector size

  for (int idx = rank_start + thread_offset; idx < rank_end; idx += thread_step) {
    // Each branch picks the multimem register width matching
    // sizeof(DataType) * kVecSize (uint4 = 16B, uint2 = 8B, uint1 = 4B).
    if constexpr (std::is_same_v<DataType, float> && (kVecSize == 4)) {
      uint4 val;  // fits 4 float elements
      mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (float*)(mc_ptr + idx));
      mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (float*)(mc_ptr + idx));
    } else if constexpr (std::is_same_v<DataType, float> && (kVecSize == 2)) {
      uint2 val;  // fits 2 float elements
      mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (float*)(mc_ptr + idx));
      mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (float*)(mc_ptr + idx));
    } else if constexpr (std::is_same_v<DataType, float> && (kVecSize == 1)) {
      uint1 val;  // fits 1 float element
      mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (float*)(mc_ptr + idx));
      mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (float*)(mc_ptr + idx));
    } else if constexpr (std::is_same_v<DataType, __half> && (kVecSize == 8)) {
      uint4 val;  // fits 8 cutlass::half_t elements; i.e., 4 half2 elements
      mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (half2*)(mc_ptr + idx));
      mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (half2*)(mc_ptr + idx));
    } else if constexpr (std::is_same_v<DataType, __half> && (kVecSize == 4)) {
      uint2 val;  // fits 4 cutlass::half_t elements; i.e., 2 half2 elements
      mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (half2*)(mc_ptr + idx));
      mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (half2*)(mc_ptr + idx));
    } else if constexpr (std::is_same_v<DataType, __half> && (kVecSize == 2)) {
      uint1 val;  // fits 2 cutlass::half_t elements; i.e., 1 half2 element
      mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (half2*)(mc_ptr + idx));
      mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (half2*)(mc_ptr + idx));
    } else {
      // not supported: cannot use static_assert because of the way TYPE is handled in this file
      assert(false);  // Unsupported data type and vector size combination
    }
  }

  // end with a barrier to ensure all devices can now read their values
  // from their own memory (that is part of the multicast memory)
  // after writing them in this kernel
  barrier(semaphores, tid, bid, num_blocks, num_ranks);
}
|
|
|
|
// NVLS allreduce entry point: dispatch on the runtime vector width and
// forward to the templated helper (inside the SM90+ guarded region).
extern "C" __global__ void __launch_bounds__(1024, 1)
    allreduce6(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores,
               mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs, int my_rank, int num_ranks, size_t num_elements,
               size_t vector_size) {
  switch (vector_size) {
    case 8:
      allreduce6_helper<TYPE, 8>(semaphores, nvlsPtrs, my_rank, num_ranks, num_elements);
      break;
    case 4:
      allreduce6_helper<TYPE, 4>(semaphores, nvlsPtrs, my_rank, num_ranks, num_elements);
      break;
    case 2:
      allreduce6_helper<TYPE, 2>(semaphores, nvlsPtrs, my_rank, num_ranks, num_elements);
      break;
    default:
      // Any other width takes the scalar path.
      allreduce6_helper<TYPE, 1>(semaphores, nvlsPtrs, my_rank, num_ranks, num_elements);
      break;
  }
}
|
|
#endif
|