mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-13 09:46:00 +00:00
fix
This commit is contained in:
@@ -13,47 +13,121 @@
|
||||
#define ALIGN 4
|
||||
__constant__ mscclppDevConn_t constDevConns[16];
|
||||
|
||||
__device__ void allgather0(mscclppDevConn_t devConn, int rank, int world_size, int remoteRank, int nelemsPerGPU)
|
||||
__device__ void allgather0(mscclppDevConn_t devConn, int rank, int world_size, int remoteRank, size_t nelemsPerGPU)
|
||||
{
|
||||
// this allgather is really simple and implemented as an alltoall
|
||||
|
||||
// this thread's role is a sender role
|
||||
// put your data asynchronously
|
||||
devConn.put(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
|
||||
if (threadIdx.x % 32 != 0)
|
||||
devConn.putWithSignal(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
|
||||
// make sure everyone is put their data before some thread randomly blocks everyone else in signal
|
||||
__syncthreads();
|
||||
// push with flag and sync to make sure the data is received
|
||||
devConn.signal();
|
||||
if (threadIdx.x % 32 != 0)
|
||||
devConn.flush();
|
||||
|
||||
// this thread's role is a receiver role. wait on the semaphore to make sure the data is ready
|
||||
devConn.wait();
|
||||
if (threadIdx.x % 32 != 0)
|
||||
devConn.wait();
|
||||
}
|
||||
|
||||
__device__ void allgather1(mscclppDevConn_t devConn, int rank, int world_size, int remoteRank, int nelemsPerGPU)
|
||||
__device__ void localAllGather(mscclppDevConn_t devConn, int rank, int world_size, int nranksPerNode, int remoteRank,
|
||||
uint64_t offset, uint64_t size)
|
||||
{
|
||||
// this allgather algorithm works as follows:
|
||||
// Step 1: GPU rank i sends data to GPU rank (i+1) % world_size
|
||||
// Step 2: GPU rank i waits for data from GPU rank (i+2) % world_size
|
||||
// Step 1: GPU rank i sends data to GPU rank (i+1) % nranksPerNode
|
||||
// and waits for data from GPU rank (i-1) % nranksPerNode
|
||||
// Step 2: GPU rank i sends data to GPU rank (i+2) % nranksPerNode
|
||||
// ...
|
||||
// This order is much better for DMA engine for NVLinks
|
||||
|
||||
for (int i = 1; i < world_size; i++) {
|
||||
__syncthreads();
|
||||
if (remoteRank != ((rank + i) % world_size))
|
||||
continue;
|
||||
// put your data to GPU (rank+i) % world_size and signal all in one call
|
||||
devConn.putWithSignal(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
|
||||
for (int i = 1; i < nranksPerNode; i++) {
|
||||
if ((remoteRank % nranksPerNode) == ((rank + i) % nranksPerNode)) {
|
||||
// put your data to GPU (rank+i) % nranksPerNode and signal in one call
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devConn.putWithSignalAndFlush(offset, size);
|
||||
}
|
||||
// wait for the data from GPU (rank-i) % nranksPerNode to arrive
|
||||
if ((remoteRank % nranksPerNode) == ((rank - i + nranksPerNode) % nranksPerNode)) {
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devConn.wait();
|
||||
}
|
||||
asm volatile("bar.sync %0, %1;" ::"r"(11), "r"((nranksPerNode - 1) * 32) : "memory");
|
||||
}
|
||||
// all connections wait for the signal from the sender
|
||||
devConn.wait();
|
||||
}
|
||||
|
||||
__global__ void kernel(int rank, int world_size, int nelemsPerGPU, int kernel)
|
||||
// Allgather variant 1: forwards this rank's whole slice to localAllGather.
//
// The slice covers nelemsPerGPU ints and starts at byte offset
// rank * nelemsPerGPU * sizeof(int). Only the first thread of every warp
// (threadIdx.x % 32 == 0) takes part; all other threads do nothing here.
__device__ void allgather1(mscclppDevConn_t devConn, int rank, int world_size, int nranksPerNode, int remoteRank,
                           size_t nelemsPerGPU)
{
  const bool isWarpLeader = (threadIdx.x % 32) == 0;
  if (isWarpLeader) {
    // size of one rank's contribution, in bytes
    const size_t sliceBytes = nelemsPerGPU * sizeof(int);
    localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank, rank * sliceBytes, sliceBytes);
  }
}
|
||||
|
||||
// Pipelined, hierarchical allgather. Per the original author's note it only works for two nodes.
//
// Each rank's slice (nelemsPerGPU ints at element offset rank * nelemsPerGPU) is split into
//   head: (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize elements
//   tail: the remaining nelemsPerGPU - head elements
// The tail is computed by subtraction so that head + tail == nelemsPerGPU even when
// nelemsPerGPU is not a multiple of pipelineSize; dividing separately would silently drop
// the last element(s) of every slice.
//
// Step 1: each node does a local allgather of the ranks' own slices and, concurrently,
//         local GPU i exchanges the head of its slice with its cross-node neighbor
//         (local GPU i on the other node) via IB.
// Step 2: each node does a local allgather of the head just received from the cross-node
//         neighbor and, concurrently, exchanges the tail with the cross-node neighbor.
// Step 3: each node does a final local allgather of the tail received in step 2.
//
// Parameters:
//   devConn       connection to remoteRank used for put/signal/flush/wait
//   rank          this GPU's rank
//   world_size    total number of ranks
//   nranksPerNode number of ranks on one node
//   remoteRank    rank on the other end of devConn
//   nelemsPerGPU  number of int elements contributed by each rank
//
// Every thread of the block must call this function: it contains __syncthreads().
__device__ void allgather2(mscclppDevConn_t devConn, int rank, int world_size, int nranksPerNode, int remoteRank,
                           size_t nelemsPerGPU)
{
  const int pipelineSize = 3;

  // role of the peer behind this connection
  const bool isLocalPeer = (remoteRank / nranksPerNode) == (rank / nranksPerNode);
  const bool isCrossNodePeer = (remoteRank % nranksPerNode) == (rank % nranksPerNode);
  // only one thread per warp drives the cross-node connection
  const bool isWarpLeader = (threadIdx.x % 32) == 0;

  // pipeline chunks, in elements; tail picks up the division remainder
  const size_t headElems = (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize;
  const size_t tailElems = nelemsPerGPU - headElems;

  // rank whose data arrives over the cross-node connection
  const int otherNghr = (rank + nranksPerNode) % world_size;

  // byte offsets / sizes of the chunks
  const size_t ownHeadOffset = rank * nelemsPerGPU * sizeof(int);
  const size_t ownTailOffset = (rank * nelemsPerGPU + headElems) * sizeof(int);
  const size_t nghrHeadOffset = otherNghr * nelemsPerGPU * sizeof(int);
  const size_t nghrTailOffset = (otherNghr * nelemsPerGPU + headElems) * sizeof(int);
  const size_t sliceBytes = nelemsPerGPU * sizeof(int);
  const size_t headBytes = headElems * sizeof(int);
  const size_t tailBytes = tailElems * sizeof(int);

  // Step 1
  // local allgather of this rank's whole slice
  if (isLocalPeer) {
    localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank, ownHeadOffset, sliceBytes);
  }
  // cross-node exchange of the head chunk
  if (isCrossNodePeer) {
    // opposite side
    if (isWarpLeader)
      devConn.putWithSignalAndFlush(ownHeadOffset, headBytes);
    if (isWarpLeader)
      devConn.wait();
  }

  __syncthreads();

  // Step 2
  // local allgather of the head chunk received from the cross-node neighbor
  if (isLocalPeer) {
    localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank, nghrHeadOffset, headBytes);
  }

  // cross-node exchange of the tail chunk
  if (isCrossNodePeer) {
    // opposite side
    if (isWarpLeader)
      devConn.putWithSignalAndFlush(ownTailOffset, tailBytes);
    if (isWarpLeader)
      devConn.wait();
  }

  __syncthreads();

  // Step 3
  // local allgather of the tail chunk received from the cross-node neighbor
  if (isLocalPeer) {
    localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank, nghrTailOffset, tailBytes);
  }
}
|
||||
|
||||
__global__ void kernel(int rank, int world_size, int nranksPerNode, size_t nelemsPerGPU, int kernel)
|
||||
{
|
||||
// find the mapping between remoteRank and devConns
|
||||
int warpId = threadIdx.x / 32;
|
||||
int remoteRank = (warpId < rank) ? warpId : warpId + 1;
|
||||
@@ -63,7 +137,9 @@ __global__ void kernel(int rank, int world_size, int nelemsPerGPU, int kernel)
|
||||
if (kernel == 0)
|
||||
allgather0(devConn, rank, world_size, remoteRank, nelemsPerGPU);
|
||||
else if (kernel == 1)
|
||||
allgather1(devConn, rank, world_size, remoteRank, nelemsPerGPU);
|
||||
allgather1(devConn, rank, world_size, nranksPerNode, remoteRank, nelemsPerGPU);
|
||||
else if (kernel == 2)
|
||||
allgather2(devConn, rank, world_size, nranksPerNode, remoteRank, nelemsPerGPU);
|
||||
}
|
||||
|
||||
void AllGatherGetCollByteCount(size_t* sendcount, size_t* recvcount, size_t* paramcount, size_t* sendInplaceOffset,
|
||||
@@ -78,19 +154,19 @@ void AllGatherGetCollByteCount(size_t* sendcount, size_t* recvcount, size_t* par
|
||||
}
|
||||
|
||||
testResult_t AllGatherInitData(struct threadArgs* args, int in_place) {
|
||||
size_t sendcount = args->sendBytes;
|
||||
// size_t sendcount = args->sendBytes;
|
||||
size_t recvcount = args->expectedBytes;
|
||||
int nranks = args->totalProcs;
|
||||
// int nranks = args->totalProcs;
|
||||
|
||||
CUDACHECK(cudaSetDevice(args->gpus[0]));
|
||||
int rank = args->proc;
|
||||
CUDACHECK(cudaMemset(args->recvbuffs[0], 0, args->expectedBytes));
|
||||
void* data = in_place ? ((char*)args->recvbuffs[0]) + rank * args->sendBytes : args->sendbuffs[0];
|
||||
// void* data = in_place ? ((char*)args->recvbuffs[0]) + rank * args->sendBytes : args->sendbuffs[0];
|
||||
|
||||
int* dataHost = new int[recvcount];
|
||||
for (int i = 0; i < static_cast<int>(recvcount); i++) {
|
||||
int val = i + 1;
|
||||
if (i / args->ranksPerNode == rank) {
|
||||
if (i / args->nranksPerNode == rank) {
|
||||
dataHost[i] = val;
|
||||
} else {
|
||||
dataHost[i] = 0;
|
||||
@@ -111,10 +187,11 @@ void AllGatherGetBw(size_t count, int typesize, double sec, double* algBw, doubl
|
||||
*busBw = baseBw * factor;
|
||||
}
|
||||
|
||||
testResult_t AllGatherRunColl(void* sendbuff, void* recvbuff, size_t count, mscclppComm_t comm, cudaStream_t stream)
|
||||
testResult_t AllGatherRunColl(void* sendbuff, void* recvbuff, int nranksPerNode, size_t count, mscclppComm_t comm,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
int worldSize = comm->nRanks;
|
||||
kernel<<<1, 32 * (worldSize - 1), 0, stream>>>(comm->rank , worldSize, count / sizeof(int), 0);
|
||||
kernel<<<1, 32 * (worldSize - 1), 0, stream>>>(comm->rank, worldSize, nranksPerNode, count / sizeof(int), 1);
|
||||
return testSuccess;
|
||||
}
|
||||
|
||||
@@ -129,9 +206,9 @@ void AllGatherGetBuffSize(size_t *sendcount, size_t *recvcount, size_t count, in
|
||||
testResult_t AllGatherRunTest(struct threadArgs* args)
|
||||
{
|
||||
args->collTest = &allGatherTest;
|
||||
mscclppDevConn_t* devConns;
|
||||
mscclppDevConn_t* devConns;
|
||||
int nCons;
|
||||
MSCCLPPCHECK(mscclppGetAllDeviceConnections(args->comms[0], &devConns, &nCons));
|
||||
MSCCLPPCHECK(mscclppGetAllDeviceConnections(args->comm, &devConns, &nCons));
|
||||
CUDACHECK(cudaMemcpyToSymbol(constDevConns, devConns, sizeof(mscclppDevConn_t) * nCons));
|
||||
TESTCHECK(TimeTest(args));
|
||||
return testSuccess;
|
||||
|
||||
116
tests/common.cu
116
tests/common.cu
@@ -29,13 +29,14 @@ static size_t maxBytes = 32*1024*1024;
|
||||
static size_t stepBytes = 1*1024*1024;
|
||||
static size_t stepFactor = 1;
|
||||
static int datacheck = 1;
|
||||
static int warmup_iters = 5;
|
||||
static int warmup_iters = 10;
|
||||
static int iters = 20;
|
||||
static int timeout = 0;
|
||||
static int report_cputime = 0;
|
||||
// Report average iteration time: (0=RANK0,1=AVG,2=MIN,3=MAX)
|
||||
static int average = 1;
|
||||
static std::string ip_port;
|
||||
static int cudaGraphLaunches = 10;
|
||||
|
||||
#define NUM_BLOCKS 32
|
||||
|
||||
@@ -117,54 +118,46 @@ testResult_t startColl(struct threadArgs* args, int in_place, int iter) {
|
||||
size_t steps = totalnbytes ? args->maxbytes / totalnbytes : 1;
|
||||
size_t shift = totalnbytes * (iter % steps);
|
||||
|
||||
for (int i = 0; i < args->nGpus; i++) {
|
||||
int rank = ((args->proc*args->nThreads + args->thread)*args->nGpus + i);
|
||||
char* recvBuff = ((char*)args->recvbuffs[i]) + shift;
|
||||
char* sendBuff = ((char*)args->sendbuffs[i]) + shift;
|
||||
int rank = ((args->proc * args->nThreads + args->thread) * args->nGpus);
|
||||
char* recvBuff = ((char*)args->recvbuffs[0]) + shift;
|
||||
char* sendBuff = ((char*)args->sendbuffs[0]) + shift;
|
||||
|
||||
TESTCHECK(args->collTest->runColl((void*)(in_place ? recvBuff + args->sendInplaceOffset * rank : sendBuff),
|
||||
(void*)(in_place ? recvBuff + args->recvInplaceOffset * rank : recvBuff), count,
|
||||
args->comms[0], args->streams[i]));
|
||||
}
|
||||
TESTCHECK(args->collTest->runColl((void*)(in_place ? recvBuff + args->sendInplaceOffset * rank : sendBuff),
|
||||
(void*)(in_place ? recvBuff + args->recvInplaceOffset * rank : recvBuff),
|
||||
args->nranksPerNode, count, args->comm, args->stream));
|
||||
return testSuccess;
|
||||
}
|
||||
|
||||
testResult_t testStreamSynchronize(int ngpus, cudaStream_t* streams)
|
||||
testResult_t testStreamSynchronize(cudaStream_t stream)
|
||||
{
|
||||
cudaError_t cudaErr;
|
||||
int remaining = ngpus;
|
||||
int* done = (int*)malloc(sizeof(int) * ngpus);
|
||||
memset(done, 0, sizeof(int) * ngpus);
|
||||
timer tim;
|
||||
|
||||
while (remaining) {
|
||||
int idle = 1;
|
||||
for (int i = 0; i < ngpus; i++) {
|
||||
if (done[i])
|
||||
continue;
|
||||
|
||||
cudaErr = cudaStreamQuery(streams[i]);
|
||||
if (cudaErr == cudaSuccess) {
|
||||
done[i] = 1;
|
||||
remaining--;
|
||||
idle = 0;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (cudaErr != cudaErrorNotReady)
|
||||
CUDACHECK(cudaErr);
|
||||
while (true) {
|
||||
cudaErr = cudaStreamQuery(stream);
|
||||
if (cudaErr == cudaSuccess) {
|
||||
break;
|
||||
}
|
||||
|
||||
// We might want to let other threads (including NCCL threads) use the CPU.
|
||||
if (idle)
|
||||
sched_yield();
|
||||
if (cudaErr != cudaErrorNotReady)
|
||||
CUDACHECK(cudaErr);
|
||||
|
||||
double delta = tim.elapsed();
|
||||
if (delta > timeout && timeout > 0) {
|
||||
char hostname[1024];
|
||||
getHostName(hostname, 1024);
|
||||
printf("%s: Test timeout (%ds) %s:%d\n", hostname, timeout, __FILE__, __LINE__);
|
||||
return testTimeout;
|
||||
}
|
||||
|
||||
// We might want to let other threads (including MSCCLPP threads) use the CPU.
|
||||
sched_yield();
|
||||
}
|
||||
free(done);
|
||||
return testSuccess;
|
||||
}
|
||||
|
||||
testResult_t completeColl(struct threadArgs* args) {
|
||||
TESTCHECK(testStreamSynchronize(args->nGpus, args->streams));
|
||||
TESTCHECK(testStreamSynchronize(args->stream));
|
||||
return testSuccess;
|
||||
}
|
||||
|
||||
@@ -255,24 +248,26 @@ testResult_t BenchTime(struct threadArgs* args, int in_place) {
|
||||
// Performance Benchmark
|
||||
cudaGraph_t graph;
|
||||
cudaGraphExec_t graphExec;
|
||||
CUDACHECK(cudaStreamBeginCapture(args->streams[0], cudaStreamCaptureModeGlobal));
|
||||
CUDACHECK(cudaStreamBeginCapture(args->stream, cudaStreamCaptureModeGlobal));
|
||||
timer tim;
|
||||
for (int iter = 0; iter < iters; iter++) {
|
||||
TESTCHECK(startColl(args, in_place, iter));
|
||||
}
|
||||
CUDACHECK(cudaStreamEndCapture(args->streams[0], &graph));
|
||||
CUDACHECK(cudaStreamEndCapture(args->stream, &graph));
|
||||
CUDACHECK(cudaGraphInstantiate(&graphExec, graph, NULL, NULL, 0));
|
||||
|
||||
// Launch the graph
|
||||
Barrier(args);
|
||||
tim.reset();
|
||||
CUDACHECK(cudaGraphLaunch(graphExec, args->streams[0]));
|
||||
for (int l = 0; l < cudaGraphLaunches; ++l) {
|
||||
CUDACHECK(cudaGraphLaunch(graphExec, args->stream));
|
||||
}
|
||||
|
||||
double cputimeSec = tim.elapsed()/(iters);
|
||||
TESTCHECK(completeColl(args));
|
||||
|
||||
double deltaSec = tim.elapsed();
|
||||
deltaSec = deltaSec/(iters);
|
||||
deltaSec = deltaSec/(iters)/(cudaGraphLaunches);
|
||||
Allreduce(args, &deltaSec, average);
|
||||
|
||||
CUDACHECK(cudaGraphExecDestroy(graphExec));
|
||||
@@ -383,13 +378,13 @@ testResult_t setupMscclppConnections(int rank, int worldSize, int ranksPerNode,
|
||||
testResult_t threadRunTests(struct threadArgs* args)
|
||||
{
|
||||
PRINT("# Setting up the connection in MSCCL++\n");
|
||||
TESTCHECK(setupMscclppConnections(args->proc, args->totalProcs, args->ranksPerNode, args->comms[0],
|
||||
TESTCHECK(setupMscclppConnections(args->proc, args->totalProcs, args->nranksPerNode, args->comm,
|
||||
args->recvbuffs[0], args->maxbytes));
|
||||
PRINT("# Launching MSCCL++ proxy threads\n");
|
||||
MSCCLPPCHECK(mscclppProxyLaunch(args->comms[0]));
|
||||
MSCCLPPCHECK(mscclppProxyLaunch(args->comm));
|
||||
TESTCHECK(mscclppTestEngine.runTest(args));
|
||||
PRINT("Stopping MSCCL++ proxy threads\n");
|
||||
MSCCLPPCHECK(mscclppProxyStop(args->comms[0]));
|
||||
MSCCLPPCHECK(mscclppProxyStop(args->comm));
|
||||
return testSuccess;
|
||||
}
|
||||
|
||||
@@ -411,6 +406,7 @@ int main(int argc, char* argv[]) {
|
||||
{"warmup_iters", required_argument, 0, 'w'},
|
||||
{"check", required_argument, 0, 'c'},
|
||||
{"timeout", required_argument, 0, 'T'},
|
||||
{"cudagraph", required_argument, 0, 'G'},
|
||||
{"report_cputime", required_argument, 0, 'C'},
|
||||
{"average", required_argument, 0, 'a'},
|
||||
{"ip_port", required_argument, 0, 'P'},
|
||||
@@ -460,6 +456,9 @@ int main(int argc, char* argv[]) {
|
||||
case 'T':
|
||||
timeout = strtol(optarg, NULL, 0);
|
||||
break;
|
||||
case 'G':
|
||||
cudaGraphLaunches = strtol(optarg, NULL, 0);
|
||||
break;
|
||||
case 'C':
|
||||
report_cputime = strtol(optarg, NULL, 0);
|
||||
break;
|
||||
@@ -481,6 +480,7 @@ int main(int argc, char* argv[]) {
|
||||
"[-w,--warmup_iters <warmup iteration count>] \n\t"
|
||||
"[-c,--check <0/1>] \n\t"
|
||||
"[-T,--timeout <time in seconds>] \n\t"
|
||||
"[-G,--cudagraph <num graph launches>] \n\t"
|
||||
"[-C,--report_cputime <0/1>] \n\t"
|
||||
"[-a,--average <0/1/2/3> report average iteration time <0=RANK0/1=AVG/2=MIN/3=MAX>] \n\t"
|
||||
"[-P,--ip_port <ip port for bootstrap>] \n\t"
|
||||
@@ -507,7 +507,7 @@ int main(int argc, char* argv[]) {
|
||||
|
||||
testResult_t run() {
|
||||
int totalProcs = 1, proc = 0;
|
||||
int ranksPerNode = 0, localRank = 0;
|
||||
int nranksPerNode = 0, localRank = 0;
|
||||
char hostname[1024];
|
||||
getHostName(hostname, 1024);
|
||||
|
||||
@@ -516,16 +516,16 @@ testResult_t run() {
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &proc);
|
||||
MPI_Comm shmcomm;
|
||||
MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &shmcomm);
|
||||
MPI_Comm_size(shmcomm, &ranksPerNode);
|
||||
MPI_Comm_size(shmcomm, &nranksPerNode);
|
||||
MPI_Comm_free(&shmcomm);
|
||||
localRank = proc % ranksPerNode;
|
||||
localRank = proc % nranksPerNode;
|
||||
#endif
|
||||
is_main_thread = is_main_proc = (proc == 0) ? 1 : 0;
|
||||
is_main_thread = is_main_proc = (proc == 0) ? 1 : 0;
|
||||
|
||||
PRINT("# minBytes %ld maxBytes %ld step: %ld(%s) warmup iters: %d iters: %d validation: %d ip port: %s\n", minBytes,
|
||||
maxBytes, (stepFactor > 1) ? stepFactor : stepBytes, (stepFactor > 1) ? "factor" : "bytes", warmup_iters, iters,
|
||||
datacheck, ip_port.c_str());
|
||||
PRINT("# minBytes %ld maxBytes %ld step: %ld(%s) warmup iters: %d iters: %d validation: %d ip port: %s graph: %d\n",
|
||||
minBytes, maxBytes, (stepFactor > 1) ? stepFactor : stepBytes, (stepFactor > 1) ? "factor" : "bytes",
|
||||
warmup_iters, iters, datacheck, ip_port.c_str(), cudaGraphLaunches);
|
||||
PRINT("#\n");
|
||||
PRINT("# Using devices\n");
|
||||
|
||||
@@ -537,9 +537,11 @@ testResult_t run() {
|
||||
int cudaDev = localRank;
|
||||
int rank = proc;
|
||||
cudaDeviceProp prop;
|
||||
char busIdChar[] = "00000000:00:00.0";
|
||||
CUDACHECK(cudaGetDeviceProperties(&prop, cudaDev));
|
||||
len += snprintf(line + len, MAX_LINE - len, "# Rank %2d Pid %6d on %10s device %2d [0x%02x] %s\n", rank, getpid(),
|
||||
hostname, cudaDev, prop.pciBusID, prop.name);
|
||||
CUDACHECK(cudaDeviceGetPCIBusId(busIdChar, sizeof(busIdChar), cudaDev));
|
||||
len += snprintf(line + len, MAX_LINE - len, "# Rank %2d Pid %6d on %10s device %2d [%s] %s\n", rank, getpid(),
|
||||
hostname, cudaDev, busIdChar, prop.name);
|
||||
maxMem = std::min(maxMem, prop.totalGlobalMem);
|
||||
|
||||
#if MSCCLPP_USE_MPI_FOR_TESTS
|
||||
@@ -577,8 +579,8 @@ testResult_t run() {
|
||||
PRINT("#\n");
|
||||
PRINT("# Initializing MSCCL++\n");
|
||||
|
||||
mscclppComm_t comms;
|
||||
MSCCLPPCHECK(mscclppCommInitRank(&comms, totalProcs, ip_port.c_str(), rank));
|
||||
mscclppComm_t comm;
|
||||
MSCCLPPCHECK(mscclppCommInitRank(&comm, totalProcs, ip_port.c_str(), rank));
|
||||
|
||||
int error = 0;
|
||||
double bw = 0.0;
|
||||
@@ -588,14 +590,14 @@ testResult_t run() {
|
||||
|
||||
fflush(stdout);
|
||||
|
||||
struct testThread thread = {0};
|
||||
struct testThread thread;
|
||||
|
||||
thread.args.minbytes = minBytes;
|
||||
thread.args.maxbytes = maxBytes;
|
||||
thread.args.stepbytes = stepBytes;
|
||||
thread.args.stepfactor = stepFactor;
|
||||
thread.args.localRank = localRank;
|
||||
thread.args.ranksPerNode = ranksPerNode;
|
||||
thread.args.nranksPerNode = nranksPerNode;
|
||||
|
||||
thread.args.totalProcs = totalProcs;
|
||||
thread.args.proc = proc;
|
||||
@@ -606,8 +608,8 @@ testResult_t run() {
|
||||
thread.args.sendbuffs = &sendbuff;
|
||||
thread.args.recvbuffs = &recvbuff;
|
||||
thread.args.expected = &expected;
|
||||
thread.args.comms = &comms;
|
||||
thread.args.streams = &stream;
|
||||
thread.args.comm = comm;
|
||||
thread.args.stream = stream;
|
||||
|
||||
thread.args.errors = &error;
|
||||
thread.args.bw = &bw;
|
||||
@@ -618,9 +620,7 @@ testResult_t run() {
|
||||
thread.func = threadRunTests;
|
||||
TESTCHECK(thread.func(&thread.args));
|
||||
|
||||
// Wait for other threads and accumulate stats and errors
|
||||
TESTCHECK(thread.ret);
|
||||
MSCCLPPCHECK(mscclppCommDestroy(comms));
|
||||
MSCCLPPCHECK(mscclppCommDestroy(comm));
|
||||
|
||||
// Free off CUDA allocated memory
|
||||
if (sendbuff)
|
||||
|
||||
@@ -70,7 +70,8 @@ struct testColl {
|
||||
size_t count, int nranks);
|
||||
testResult_t (*initData)(struct threadArgs* args, int in_place);
|
||||
void (*getBw)(size_t count, int typesize, double sec, double* algBw, double* busBw, int nranks);
|
||||
testResult_t (*runColl)(void* sendbuff, void* recvbuff, size_t count, mscclppComm_t comm, cudaStream_t stream);
|
||||
testResult_t (*runColl)(void* sendbuff, void* recvbuff, int nranksPerNode, size_t count, mscclppComm_t comm,
|
||||
cudaStream_t stream);
|
||||
};
|
||||
|
||||
struct testEngine
|
||||
@@ -97,14 +98,14 @@ struct threadArgs
|
||||
int nGpus;
|
||||
int* gpus;
|
||||
int localRank;
|
||||
int ranksPerNode;
|
||||
int nranksPerNode;
|
||||
void** sendbuffs;
|
||||
size_t sendBytes;
|
||||
size_t sendInplaceOffset;
|
||||
void** recvbuffs;
|
||||
size_t recvInplaceOffset;
|
||||
mscclppComm_t* comms;
|
||||
cudaStream_t* streams;
|
||||
mscclppComm_t comm;
|
||||
cudaStream_t stream;
|
||||
|
||||
void** expected;
|
||||
size_t expectedBytes;
|
||||
@@ -123,7 +124,6 @@ struct testThread
|
||||
pthread_t thread;
|
||||
threadFunc_t func;
|
||||
struct threadArgs args;
|
||||
testResult_t ret;
|
||||
};
|
||||
|
||||
// Provided by common.cu
|
||||
|
||||
Reference in New Issue
Block a user