Mirror of https://github.com/microsoft/mscclpp.git, synced 2026-05-12 17:26:04 +00:00.
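Summary of the change, as reconstructed from the hunks below: the device-side allgather helpers lose their unused `world_size` parameter, the `MSCCLPPCHECK` macro (no longer referenced) is removed, a dead `rankToNode()` lookup in `main()` is deleted, and the CUDA graph capture and launch calls are wrapped in `CUDACHECK` so their return codes are no longer silently ignored.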
```diff
@@ -18,18 +18,6 @@
 static int nranksPerNode = 8;
 
-// Propagate errors up
-#define MSCCLPPCHECK(call)                                                                 \
-  do {                                                                                     \
-    mscclppResult_t res = call;                                                            \
-    if (res != mscclppSuccess && res != mscclppInProgress) {                               \
-      /* Print the back trace */                                                           \
-      printf("Failure at %s:%d -> %s\n", __FILE__, __LINE__, mscclppGetErrorString(res));  \
-      return res;                                                                          \
-    }                                                                                      \
-  } while (0)
-
 // Check CUDA RT calls
 #define CUDACHECK(cmd) \
   do { \
```
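The `CUDACHECK` body is cut off by the hunk boundary above. For orientation, a CUDA runtime guard of this shape typically looks like the sketch below; the exact recovery path (abort vs. return) in this file is an assumption.

```cpp
// Sketch of a typical CUDACHECK body; the real one lies outside the hunk.
// exit() on failure is an assumption about this file's recovery policy.
#define CUDACHECK(cmd)                                                     \
  do {                                                                     \
    cudaError_t err = cmd;                                                 \
    if (err != cudaSuccess) {                                              \
      printf("CUDA failure at %s:%d -> %s\n", __FILE__, __LINE__,          \
             cudaGetErrorString(err));                                     \
      exit(EXIT_FAILURE);                                                  \
    }                                                                      \
  } while (0)
```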
```diff
@@ -54,8 +42,7 @@ template <class T>
 using DeviceHandle = mscclpp::DeviceHandle<T>;
 __constant__ DeviceHandle<mscclpp::SimpleProxyChannel> constProxyChans[16];
 
-__device__ void allgather0(DeviceHandle<mscclpp::SimpleProxyChannel> proxyChan, int rank, int world_size,
-                           int remoteRank, size_t nelemsPerGPU) {
+__device__ void allgather0(DeviceHandle<mscclpp::SimpleProxyChannel> proxyChan, int rank, size_t nelemsPerGPU) {
   // this allgather is really simple and implemented as an alltoall
 
   // this thread's role is a sender role
```
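The body of `allgather0` sits between this hunk and the next; only its closing `wait()` is visible as context below. A hedged reconstruction of the alltoall step each warp performs, assuming `SimpleProxyChannel`'s `putWithSignal`/`flush`/`wait` device API:

```cpp
// Hedged sketch of the step between the hunks: each warp pushes this rank's
// chunk to its peer, flushes the proxy, then waits for the peer's signal.
if ((threadIdx.x % 32) == 0) {
  proxyChan.putWithSignal(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
  proxyChan.flush();
}
// wait for the inbound chunk (this line is the context at the top of the next hunk)
if ((threadIdx.x % 32) == 0) proxyChan.wait();
```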
```diff
@@ -70,8 +57,8 @@ __device__ void allgather0(DeviceHandle<mscclpp::SimpleProxyChannel> proxyChan,
   if ((threadIdx.x % 32) == 0) proxyChan.wait();
 }
 
-__device__ void localAllGather(DeviceHandle<mscclpp::SimpleProxyChannel> proxyChan, int rank, int world_size,
-                               int nranksPerNode, int remoteRank, uint64_t offset, uint64_t size) {
+__device__ void localAllGather(DeviceHandle<mscclpp::SimpleProxyChannel> proxyChan, int rank, int nranksPerNode,
+                               int remoteRank, uint64_t offset, uint64_t size) {
   // this allgather algorithm works as follows:
   // Step 1: GPU rank i sends data to GPU rank (i+1) % nranksPerNode
   // and waits for data from GPU rank (i-1) % nranksPerNode
```
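The Step 1 comment describes a ring within the node. A minimal sketch of that schedule, assuming one channel per neighbor with `putWithSignal`/`wait` semantics (the file's actual loop, offsets, and peer bookkeeping differ):

```cpp
// Ring schedule sketch: in round s, forward the chunk received in round s-1.
// After nranksPerNode - 1 rounds, every GPU in the node holds all chunks.
int localRank = rank % nranksPerNode;
for (int s = 0; s < nranksPerNode - 1; ++s) {
  // chunk that originated at local rank (localRank - s), wrapped into [0, nranksPerNode)
  int srcLocal = (localRank - s + nranksPerNode) % nranksPerNode;
  if ((threadIdx.x % 32) == 0) {
    proxyChan.putWithSignal(offset + srcLocal * size, size);  // send to (rank + 1) % nranksPerNode
    proxyChan.wait();                                         // recv from (rank - 1) % nranksPerNode
  }
  __syncthreads();
}
```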
```diff
@@ -91,9 +78,9 @@ __device__ void localAllGather(DeviceHandle<mscclpp::SimpleProxyChannel> proxyCh
   }
 }
 
-__device__ void allgather1(DeviceHandle<mscclpp::SimpleProxyChannel> proxyChan, int rank, int world_size,
-                           int nranksPerNode, int remoteRank, size_t nelemsPerGPU) {
-  localAllGather(proxyChan, rank, world_size, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int),
+__device__ void allgather1(DeviceHandle<mscclpp::SimpleProxyChannel> proxyChan, int rank, int nranksPerNode,
+                           int remoteRank, size_t nelemsPerGPU) {
+  localAllGather(proxyChan, rank, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int),
                  nelemsPerGPU * sizeof(int));
   if (remoteRank / nranksPerNode == rank / nranksPerNode)
     if ((threadIdx.x % 32) == 0) proxyChan.flush();
@@ -116,7 +103,7 @@ __device__ void allgather2(DeviceHandle<mscclpp::SimpleProxyChannel> proxyChan,
   // Step 1
   // local allgather
   if (remoteRank / nranksPerNode == rank / nranksPerNode) {
-    localAllGather(proxyChan, rank, world_size, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int),
+    localAllGather(proxyChan, rank, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int),
                    nelemsPerGPU * sizeof(int));
   }
   // cross-node exchange
@@ -134,7 +121,7 @@ __device__ void allgather2(DeviceHandle<mscclpp::SimpleProxyChannel> proxyChan,
   // local allgather
   int otherNghr = (rank + nranksPerNode) % world_size;
   if (remoteRank / nranksPerNode == rank / nranksPerNode) {
-    localAllGather(proxyChan, rank, world_size, nranksPerNode, remoteRank, otherNghr * nelemsPerGPU * sizeof(int),
+    localAllGather(proxyChan, rank, nranksPerNode, remoteRank, otherNghr * nelemsPerGPU * sizeof(int),
                    (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize * sizeof(int));
   }
 
@@ -152,7 +139,7 @@ __device__ void allgather2(DeviceHandle<mscclpp::SimpleProxyChannel> proxyChan,
   // Step 3
   // local allgather
   if (remoteRank / nranksPerNode == rank / nranksPerNode) {
-    localAllGather(proxyChan, rank, world_size, nranksPerNode, remoteRank,
+    localAllGather(proxyChan, rank, nranksPerNode, remoteRank,
                    (otherNghr * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) * sizeof(int),
                    nelemsPerGPU / pipelineSize * sizeof(int));
   }
```
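`allgather2` splits each cross-node chunk by `pipelineSize` so the bulk of the intra-node forwarding overlaps the inter-node transfer, and a short tail pass (Step 3) finishes the remainder. The split arithmetic, with illustrative numbers:

```cpp
// Illustrative numbers only: with pipelineSize = 3 and nelemsPerGPU = 12,
// the bulk pass forwards 12 * 2 / 3 = 8 ints and the tail pass the
// remaining 12 / 3 = 4 ints, starting right after the bulk region.
constexpr int pipelineSize = 3;
constexpr size_t nelemsPerGPU = 12;
constexpr size_t bulk = (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize;  // 8
constexpr size_t tail = nelemsPerGPU / pipelineSize;                         // 4
static_assert(bulk + tail == nelemsPerGPU, "the two passes cover the whole chunk");
```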
```diff
@@ -170,9 +157,9 @@ __global__ void kernel(int rank, int world_size, int nranksPerNode, size_t nelem
   DeviceHandle<mscclpp::SimpleProxyChannel> proxyChan = constProxyChans[warpId];
 
   if (kernel == 0)
-    allgather0(proxyChan, rank, world_size, remoteRank, nelemsPerGPU);
+    allgather0(proxyChan, rank, nelemsPerGPU);
   else if (kernel == 1)
-    allgather1(proxyChan, rank, world_size, nranksPerNode, remoteRank, nelemsPerGPU);
+    allgather1(proxyChan, rank, nranksPerNode, remoteRank, nelemsPerGPU);
   else if (kernel == 2)
     allgather2(proxyChan, rank, world_size, nranksPerNode, remoteRank, nelemsPerGPU);
 }
```
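The `warpId` and `remoteRank` used here are computed just above the hunk, outside the visible context. A hedged sketch of that mapping, consistent with the `32 * (world_size - 1)` block size at the launch site further down: one warp per remote peer, skipping the local rank.

```cpp
// Assumed mapping (the real lines sit above this hunk): warp w talks to the
// w-th remote peer, so peers below `rank` map directly and peers above it
// shift by one to skip self.
int warpId = threadIdx.x / 32;
int remoteRank = (warpId < rank) ? warpId : warpId + 1;
DeviceHandle<mscclpp::SimpleProxyChannel> proxyChan = constProxyChans[warpId];
```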
```diff
@@ -388,7 +375,6 @@ int main(int argc, const char* argv[]) {
   }
   ip_port = (char*)parsedArgs["ip_port"].c_str();
 
-  int thisNode = rankToNode(rank);
   int cudaNum = rankToLocalRank(rank);
   CUDACHECK(cudaSetDevice(cudaNum));
 
@@ -452,19 +438,19 @@ int main(int argc, const char* argv[]) {
   if (rank == 0) printf("Capturing %d iterations of the kernel in a CUDA graph\n", cudagraphiter);
   cudaGraph_t graph;
   cudaGraphExec_t instance;
-  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
+  CUDACHECK(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
   for (int i = 0; i < cudagraphiter; ++i) {
     kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
   }
-  cudaStreamEndCapture(stream, &graph);
-  cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);
+  CUDACHECK(cudaStreamEndCapture(stream, &graph));
+  CUDACHECK(cudaGraphInstantiate(&instance, graph, NULL, NULL, 0));
 
   int cudagraphwarmup = 10;
   if (rank == 0)
     printf("Warming up %d iterations of the CUDA graph with %d iterations of the kernel\n", cudagraphwarmup,
            cudagraphiter);
   for (int i = 0; i < cudagraphwarmup; ++i) {
-    cudaGraphLaunch(instance, stream);
+    CUDACHECK(cudaGraphLaunch(instance, stream));
   }
   CUDACHECK(cudaStreamSynchronize(stream));
 
```
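The point of this hunk is that stream capture and graph instantiation were previously unchecked: a failed `cudaStreamBeginCapture` would otherwise surface only as a confusing error much later. For reference, the full capture/replay lifecycle, including the teardown the diff does not show (`stream` and the enqueued work are placeholders):

```cpp
// Capture/replay lifecycle sketch; `stream` and the captured launches are
// placeholders. Every call can fail, hence the CUDACHECK on each step.
cudaGraph_t graph;
cudaGraphExec_t instance;
CUDACHECK(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
// ... enqueue kernel launches on `stream`; nothing executes during capture ...
CUDACHECK(cudaStreamEndCapture(stream, &graph));
CUDACHECK(cudaGraphInstantiate(&instance, graph, NULL, NULL, 0));
CUDACHECK(cudaGraphLaunch(instance, stream));  // replay the captured work
CUDACHECK(cudaStreamSynchronize(stream));
CUDACHECK(cudaGraphExecDestroy(instance));     // teardown not shown in the diff
CUDACHECK(cudaGraphDestroy(graph));
```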
```diff
@@ -477,7 +463,7 @@ int main(int argc, const char* argv[]) {
   double t0, t1, ms, time_in_us;
   t0 = getTime();
   for (int i = 0; i < cudagraphlaunch; ++i) {
-    cudaGraphLaunch(instance, stream);
+    CUDACHECK(cudaGraphLaunch(instance, stream));
   }
   CUDACHECK(cudaStreamSynchronize(stream));
 
```