Merge branch 'main' into binyli/merge-main

This commit is contained in:
Binyang Li
2023-05-10 06:02:22 +00:00
7 changed files with 231 additions and 152 deletions

View File

@@ -155,7 +155,7 @@ TESTSOBJTARGETS := $(TESTSOBJS:%=$(BUILDDIR)/$(OBJDIR)/%)
TESTSBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(TESTSOBJS))
MSCLLPPTESTSOBJSDIR:= $(BUILDDIR)/$(OBJDIR)/$(TESTSDIR)
MSCLLPPTESTBINFILESLIST := allgather_test allreduce_test ring_send_recv_test
MSCLLPPTESTBINFILESLIST := allgather_test allreduce_test sendrecv_test
MSCLLPPTESTBINS := $(MSCLLPPTESTBINFILESLIST:%=$(BUILDDIR)/$(BINDIR)/$(TESTSDIR)/%_perf)
INCLUDE := -Isrc -Isrc/include

View File

@@ -195,8 +195,8 @@ testResult_t AllGatherRunColl(void* sendbuff, void* recvbuff, int nranksPerNode,
return testSuccess;
}
struct testColl allGatherTest = {"AllGather", AllGatherGetCollByteCount, AllGatherInitData, AllGatherGetBw,
AllGatherRunColl};
struct testColl allGatherTest = {"AllGather", AllGatherGetCollByteCount, defaultInitColl, AllGatherInitData,
AllGatherGetBw, AllGatherRunColl};
void AllGatherGetBuffSize(size_t* sendcount, size_t* recvcount, size_t count, int nranks)
{

View File

@@ -204,8 +204,8 @@ testResult_t AllReduceRunColl(void* sendbuff, void* recvbuff, int nranksPerNode,
return testSuccess;
}
struct testColl allReduceTest = {"AllReduce", AllReduceGetCollByteCount, AllReduceInitData, AllReduceGetBw,
AllReduceRunColl};
struct testColl allReduceTest = {"AllReduce", AllReduceGetCollByteCount, defaultInitColl, AllReduceInitData,
AllReduceGetBw, AllReduceRunColl};
testResult_t AllReduceSetupMscclppConnections(struct testArgs* args)
{

View File

@@ -212,11 +212,8 @@ template <typename T> void Allreduce(struct testArgs* args, T* value, int averag
*value = accumulator;
}
testResult_t CheckData(struct testArgs* args, int in_place, int64_t* wrongElts)
testResult_t CheckData(struct testArgs* args, int64_t* wrongElts)
{
if (in_place == 0) {
return testInternalError;
}
size_t count = args->expectedBytes / sizeof(int);
int* dataHostRecv = new int[count];
@@ -226,6 +223,8 @@ testResult_t CheckData(struct testArgs* args, int in_place, int64_t* wrongElts)
for (size_t i = 0; i < count; i++) {
if (dataHostRecv[i] != dataHostExpected[i]) {
// PRINT("Error: dataHostRecv[%ld] = %d, dataHostExpected[%ld] = %d\n", i, dataHostRecv[i], i,
// dataHostExpected[i]);
*wrongElts += 1;
}
}
@@ -299,7 +298,7 @@ testResult_t BenchTime(struct testArgs* args, int in_place)
CUDACHECK(cudaGraphExecDestroy(graphExec));
CUDACHECK(cudaGraphDestroy(graph));
TESTCHECK(CheckData(args, in_place, &wrongElts));
TESTCHECK(CheckData(args, &wrongElts));
// aggregate delta from all threads and procs
long long wrongElts1 = wrongElts;
@@ -316,6 +315,9 @@ testResult_t BenchTime(struct testArgs* args, int in_place)
} else {
sprintf(timeStr, "%7.2f", timeUsec);
}
if (!in_place) {
PRINT(" ");
}
if (args->reportErrors) {
PRINT(" %7s %6.2f %6.2f %5g", timeStr, algBw, busBw, (double)wrongElts);
} else {
@@ -327,7 +329,7 @@ testResult_t BenchTime(struct testArgs* args, int in_place)
return testSuccess;
}
void setupArgs(size_t size, struct testArgs* args)
testResult_t setupArgsAndInit(size_t size, struct testArgs* args)
{
int nranks = args->totalProcs;
size_t count, sendCount, recvCount, paramCount, sendInplaceOffset, recvInplaceOffset;
@@ -343,6 +345,8 @@ void setupArgs(size_t size, struct testArgs* args)
args->expectedBytes = recvCount * typeSize;
args->sendInplaceOffset = sendInplaceOffset * typeSize;
args->recvInplaceOffset = recvInplaceOffset * typeSize;
return args->collTest->initColl();
}
testResult_t TimeTest(struct testArgs* args)
@@ -351,7 +355,7 @@ testResult_t TimeTest(struct testArgs* args)
TESTCHECK(Barrier(args));
// Warm-up for large size
setupArgs(args->maxbytes, args);
TESTCHECK(setupArgsAndInit(args->maxbytes, args));
TESTCHECK(args->collTest->initData(args, 1));
for (int iter = 0; iter < warmup_iters; iter++) {
TESTCHECK(startColl(args, 1, iter));
@@ -359,7 +363,7 @@ testResult_t TimeTest(struct testArgs* args)
TESTCHECK(completeColl(args));
// Warm-up for small size
setupArgs(args->minbytes, args);
TESTCHECK(setupArgsAndInit(args->minbytes, args));
for (int iter = 0; iter < warmup_iters; iter++) {
TESTCHECK(startColl(args, 1, iter));
}
@@ -374,11 +378,9 @@ testResult_t TimeTest(struct testArgs* args)
// Benchmark
for (size_t size = args->minbytes; size <= args->maxbytes;
size = ((args->stepfactor > 1) ? size * args->stepfactor : size + args->stepbytes)) {
setupArgs(size, args);
TESTCHECK(setupArgsAndInit(size, args));
PRINT("%12li %12li", max(args->sendBytes, args->expectedBytes), args->nbytes / sizeof(int));
// Don't support out-of-place for now
// TESTCHECK(BenchTime(args, 0));
TESTCHECK(BenchTime(args, 1));
TESTCHECK(BenchTime(args, args->in_place));
PRINT("\n");
}
return testSuccess;
@@ -644,6 +646,7 @@ testResult_t run()
worker.args.stepfactor = stepFactor;
worker.args.localRank = localRank;
worker.args.nranksPerNode = nranksPerNode;
worker.args.in_place = 1;
worker.args.totalProcs = totalProcs;
worker.args.proc = proc;

View File

@@ -64,11 +64,17 @@ typedef enum
testNumResults = 5
} testResult_t;
// Default testColl::initColl hook for collectives that need no per-run
// device-side initialization; always reports success.
inline testResult_t defaultInitColl()
{
    return testSuccess;
}
struct testColl
{
const char name[20];
void (*getCollByteCount)(size_t* sendcount, size_t* recvcount, size_t* paramcount, size_t* sendInplaceOffset,
size_t* recvInplaceOffset, size_t count, int nranks);
testResult_t (*initColl)();
testResult_t (*initData)(struct testArgs* args, int in_place);
void (*getBw)(size_t count, int typesize, double sec, double* algBw, double* busBw, int nranks);
testResult_t (*runColl)(void* sendbuff, void* recvbuff, int nranksPerNode, size_t count, mscclppComm_t comm,
@@ -100,6 +106,7 @@ struct testArgs
int localRank;
int nranksPerNode;
int kernel_num;
int in_place;
void* sendbuff;
size_t sendBytes;
size_t sendInplaceOffset;
@@ -153,4 +160,4 @@ inline void print_usage(const char* prog)
if (is_main_thread) \
printf
#endif // MSCCLPP_TESTS_COMMON_H_
#endif // MSCCLPP_TESTS_COMMON_H_

View File

@@ -1,134 +0,0 @@
#include "comm.h"
#include "common.h"
#include <stdio.h>
#include <stdlib.h>
#include <string>
#include <unistd.h>
#define BLOCK_THREADS_NUM 128
#define ALIGN 4
// Fill dataDst[0..dataCount) with the repeating reference pattern
// value = index % 256. Single-block launch: each thread covers a
// blockDim.x-strided subset of the indices.
__global__ void initKernel(int* dataDst, int dataCount)
{
    size_t idx = threadIdx.x;
    while (idx < dataCount) {
        dataDst[idx] = idx % 256;
        idx += blockDim.x;
    }
}
__constant__ mscclppDevConn_t sendConnConst;
__constant__ mscclppDevConn_t recvConnConst;
// Ring relay kernel (single block of BLOCK_THREADS_NUM threads).
// Root: push the whole local buffer to the next rank, signal, then wait for
// the signal that comes back around the ring.
// Non-root: wait for the previous rank's signal, then forward the received
// data to the next rank and signal it.
// dataSize is passed straight to putDirect — presumably bytes; confirm
// against mscclppDevConn_t::putDirect.
__global__ void kernel(bool root, size_t dataSize)
{
    // Connections are staged in __constant__ memory by the host before launch.
    mscclppDevConn_t sendConn = sendConnConst;
    mscclppDevConn_t recvConn = recvConnConst;
    if (root) {
        sendConn.putDirect(0, dataSize, threadIdx.x, blockDim.x);
        // make sure all the threads have put their data
        __syncthreads();
        if (threadIdx.x == 0) {
            sendConn.signalDirect();
            recvConn.waitDirect();
        }
    } else {
        // Only one thread polls the incoming signal.
        if (threadIdx.x == 0) {
            recvConn.waitDirect();
        }
        // make sure we get the latest data
        __syncthreads();
        sendConn.putDirect(0, dataSize, threadIdx.x, blockDim.x);
        __syncthreads();
        if (threadIdx.x == 0) {
            sendConn.signalDirect();
        }
    }
}
// Prepare a rank's receive buffer: the root rank gets the reference pattern
// (via initKernel), every other rank starts zeroed so the relayed data is
// observable. Returns testSuccess (CUDACHECK may return early on CUDA error).
testResult_t resetData(int* dataDst, size_t dataCount, bool isRoot)
{
    if (!isRoot) {
        CUDACHECK(cudaMemset(dataDst, 0, dataCount * sizeof(int)));
        return testSuccess;
    }
    initKernel<<<1, BLOCK_THREADS_NUM>>>(dataDst, dataCount);
    return testSuccess;
}
// Report buffer sizes for the ring test: every count is the requested count
// rounded down to an ALIGN-element boundary; in-place recv offset is 0.
void RingSendRecvGetCollByteCount(size_t* sendcount, size_t* recvcount, size_t* paramcount, size_t* sendInplaceOffset,
                                  size_t* recvInplaceOffset, size_t count, int nranks)
{
    // Equivalent to (count / ALIGN) * ALIGN for unsigned count.
    const size_t aligned = count - (count % ALIGN);
    *sendcount = aligned;
    *recvcount = aligned;
    *sendInplaceOffset = aligned;
    *recvInplaceOffset = 0;
    *paramcount = aligned;
}
// Initialize the ring test: rank 0's recv buffer gets the reference pattern
// (i % 256, via resetData/initKernel) and every other rank's is zeroed;
// args->expected receives the same reference pattern on all ranks so the
// post-ring contents can be verified. Ends with a bootstrap barrier so no
// rank starts before its peers are initialized.
// NOTE(review): `in_place` is accepted but unused here.
testResult_t RingSendRecvInitData(struct testArgs* args, int in_place)
{
    size_t recvcount = args->expectedBytes / sizeof(int);
    CUDACHECK(cudaSetDevice(args->gpuNum));
    int rank = args->proc;
    CUDACHECK(cudaMemset(args->recvbuff, 0, args->expectedBytes));
    resetData((int*)args->recvbuff, recvcount, rank == 0);
    // Host-side staging buffer for the expected pattern.
    // NOTE(review): if a CUDACHECK below returns early, this leaks; harmless
    // in a test binary but std::vector would be tidier.
    int* dataHost = new int[recvcount];
    for (size_t i = 0; i < recvcount; i++) {
        dataHost[i] = i % 256;
    }
    CUDACHECK(cudaMemcpy(args->expected, dataHost, recvcount * sizeof(int), cudaMemcpyHostToDevice));
    // BUG FIX: was `delete dataHost;` — memory obtained with new[] must be
    // released with delete[]; scalar delete here is undefined behavior.
    delete[] dataHost;
    CUDACHECK(cudaDeviceSynchronize());
    MSCCLPPCHECK(mscclppBootstrapBarrier(args->comm));
    return testSuccess;
}
// Bandwidth model for the ring relay.
// algBw: total bytes moved around the ring (count * typesize * nranks) in
// GB/s over the measured time.
// busBw: algBw scaled by (nranks-1)/nranks, the standard ring correction.
void RingSendRecvGetBw(size_t count, int typesize, double sec, double* algBw, double* busBw, int nranks)
{
    const double baseBw = (double)(count * typesize * nranks) / 1.0E9 / sec;
    const double ringFactor = ((double)(nranks - 1)) / ((double)nranks);
    *algBw = baseBw;
    *busBw = baseBw * ringFactor;
}
// testColl::runColl hook for the ring test: single-block launch on the given
// stream; rank 0 acts as the ring root. `count` is forwarded to the kernel
// as dataSize (presumably bytes, matching putDirect — confirm).
// NOTE(review): sendbuff/recvbuff/nranksPerNode/kernel_num are unused; the
// kernel reads its connections from __constant__ memory instead.
testResult_t RingSendRecvRunColl(void* sendbuff, void* recvbuff, int nranksPerNode, size_t count, mscclppComm_t comm,
                                 cudaStream_t stream, int kernel_num)
{
    kernel<<<1, BLOCK_THREADS_NUM, 0, stream>>>(comm->rank == 0, count);
    return testSuccess;
}
// Collective descriptor handed to the common test harness.
// NOTE(review): this five-entry initializer must match the testColl layout
// this file was built against (no initColl slot here).
struct testColl ringSendRecvTest = {"RingSendRecvTest", RingSendRecvGetCollByteCount, RingSendRecvInitData,
                                    RingSendRecvGetBw, RingSendRecvRunColl};
// Harness hook: only send/recv counts are needed by the caller, so the
// remaining outputs of the byte-count helper are computed and discarded.
void RingSendRecvGetBuffSize(size_t* sendcount, size_t* recvcount, size_t count, int nranks)
{
    size_t ignoredParam;
    size_t ignoredSendOff;
    size_t ignoredRecvOff;
    RingSendRecvGetCollByteCount(sendcount, recvcount, &ignoredParam, &ignoredSendOff, &ignoredRecvOff, count,
                                 nranks);
}
// Test entry point: registers the ring collective, looks up the device
// connections to the next rank (send) and previous rank (recv), stages them
// in __constant__ memory for the kernel, then runs the timed harness.
testResult_t RingSendRecvRunTest(struct testArgs* args)
{
    args->collTest = &ringSendRecvTest;
    int rank = args->proc, worldSize = args->totalProcs;
    mscclppDevConn_t* sendDevConn;
    mscclppDevConn_t* recvDevConn;
    // Ring neighbors; tag 0 is used for both directions here.
    MSCCLPPCHECK(mscclppGetDeviceConnection(args->comm, (rank + 1) % worldSize, 0, &sendDevConn));
    MSCCLPPCHECK(mscclppGetDeviceConnection(args->comm, (rank - 1 + worldSize) % worldSize, 0, &recvDevConn));
    CUDACHECK(cudaMemcpyToSymbol(sendConnConst, sendDevConn, sizeof(mscclppDevConn_t)));
    CUDACHECK(cudaMemcpyToSymbol(recvConnConst, recvDevConn, sizeof(mscclppDevConn_t)));
    TESTCHECK(TimeTest(args));
    return testSuccess;
}
// Engine hooks for the harness; the two trailing slots (used elsewhere for
// connection-setup callbacks) are unused by this test. The weak alias makes
// this the mscclppTestEngine the common driver resolves when linked alone.
struct testEngine ringSendRecvTestEngine = {RingSendRecvGetBuffSize, RingSendRecvRunTest, nullptr, nullptr};
#pragma weak mscclppTestEngine = ringSendRecvTestEngine

203
tests/sendrecv_test.cu Normal file
View File

@@ -0,0 +1,203 @@
#include "comm.h"
#include "common.h"

#include <algorithm>
#include <array>
#include <string>
#include <vector>

#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
constexpr size_t BLOCK_THREADS_NUM = 1024;
// Try to use more blocks if per-block data size exceeds this threshold
constexpr size_t THRES_BYTES_PER_BLOCK = 8192;
// Let it no more than the number of SMs on a GPU
constexpr size_t MAX_BLOCKS_NUM = 32;
#define ALIGN 4
__constant__ mscclppDevConn_t sendConnConst;
__constant__ mscclppDevConn_t recvConnConst;
// Shared state for the inter-block barrier implemented by sync_gpu. Lives in
// device global memory (GLOBAL_SYNC_STATE) and is zero-initialized from the
// host in SendRecvInitColl before each run.
struct SyncGpuState
{
    // Release flag the block leaders busy-wait on; volatile so each spin
    // iteration re-reads it from memory.
    volatile int flag;
    // Number of blocks that have arrived at the current barrier phase.
    int cnt;
    // Phase toggle: alternates between the count-up and count-down phases so
    // the same state can back consecutive barriers without a reset.
    int is_add;
};
// Synchronize multiple thread blocks inside a kernel. Guarantee that all
// previous work of all threads in cooperating blocks is finished and
// visible to all threads in the device.
//
// Implementation: a sense-reversing centralized barrier. Thread 0 of each
// block atomically increments (up phase) or decrements (down phase)
// state.cnt; the last arriver flips state.flag, releasing the other block
// leaders, which busy-wait on the volatile flag. state.is_add alternates the
// phase each call. All leaders store the same phase value back, so the
// concurrent writes to is_add are benign.
// NOTE(review): there is no explicit __threadfence() here; visibility of
// pre-barrier global writes relies on the atomics/volatile accesses and the
// putDirect path — confirm this ordering is sufficient on target GPUs.
__forceinline__ __device__ void sync_gpu(SyncGpuState& state, int blockNum)
{
    int maxOldCnt = blockNum - 1;
    // Make sure every thread in this block has finished before the leader
    // announces the block's arrival.
    __syncthreads();
    if (threadIdx.x == 0) {
        // Flip the phase for this barrier instance.
        int is_add_ = state.is_add ^ 1;
        if (is_add_) {
            // Up phase: the last block to arrive (old count == blockNum - 1)
            // raises the flag.
            if (atomicAdd(&state.cnt, 1) == maxOldCnt) {
                state.flag = 1;
            }
            while (!state.flag) {
            }
        } else {
            // Down phase: the last block to leave (old count == 1) lowers it.
            if (atomicSub(&state.cnt, 1) == 1) {
                state.flag = 0;
            }
            while (state.flag) {
            }
        }
        state.is_add = is_add_;
    }
    // We need sync here because only a single thread is checking whether
    // the flag is flipped.
    __syncthreads();
}
// Connection tag for the send direction between this rank and `peer`.
// Convention: the lower-ranked endpoint of a pair sends on tag 0, the
// higher-ranked one on tag 1, so the two directions never collide.
inline int getSendTag(int rank, int peer)
{
    if (rank < peer) {
        return 0;
    }
    return 1;
}
// Connection tag for the receive direction between this rank and `peer`.
// Mirror of getSendTag: a receive from `peer` uses the tag `peer` sends on.
inline int getRecvTag(int rank, int peer)
{
    if (rank < peer) {
        return 1;
    }
    return 0;
}
// Choose the launch grid size for `count` bytes: ceil-divide by the
// per-block byte threshold, capped at MAX_BLOCKS_NUM.
inline int getBlockNum(size_t count)
{
    const size_t wanted = (count + THRES_BYTES_PER_BLOCK - 1) / THRES_BYTES_PER_BLOCK;
    return std::min(wanted, MAX_BLOCKS_NUM);
}
__device__ SyncGpuState GLOBAL_SYNC_STATE;
// Send/recv kernel: each block pushes its slice of the send buffer to the
// next rank via putDirect, all blocks rendezvous through the device-wide
// barrier, then global thread 0 signals completion and waits for the peer's
// matching signal.
// Launch: <<<blockNum, BLOCK_THREADS_NUM>>>. dataSize/dataPerBlock are in the
// units expected by putDirect (presumably bytes, per THRES_BYTES_PER_BLOCK —
// confirm against mscclppDevConn_t::putDirect).
// NOTE(review): `rank` is currently unused.
__global__ void kernel(int rank, size_t dataSize, size_t dataPerBlock)
{
    // Connections are staged in __constant__ memory by the host before launch.
    mscclppDevConn_t sendConn = sendConnConst;
    mscclppDevConn_t recvConn = recvConnConst;
    // This block's slice; the last block may get a short remainder.
    size_t startIndex = blockIdx.x * dataPerBlock;
    size_t blockDataSize = min(dataSize - startIndex, dataPerBlock);
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    sendConn.putDirect(startIndex, blockDataSize, threadIdx.x, blockDim.x);
    // Every block must have completed its put before the single signal goes out.
    sync_gpu(GLOBAL_SYNC_STATE, gridDim.x);
    if (tid == 0) {
        sendConn.signalDirect();
        recvConn.waitDirect();
    }
}
// Report buffer sizes for the send/recv test: every count is the requested
// count rounded down to an ALIGN-element boundary; in-place recv offset is 0.
void SendRecvGetCollByteCount(size_t* sendcount, size_t* recvcount, size_t* paramcount, size_t* sendInplaceOffset,
                              size_t* recvInplaceOffset, size_t count, int nranks)
{
    // Equivalent to (count / ALIGN) * ALIGN for unsigned count.
    const size_t aligned = count - (count % ALIGN);
    *sendcount = aligned;
    *recvcount = aligned;
    *sendInplaceOffset = aligned;
    *recvInplaceOffset = 0;
    *paramcount = aligned;
}
// testColl::initColl hook: reset the device-wide barrier state so sync_gpu
// starts each run with cnt == 0, flag == 0, and phase is_add == 0.
testResult_t SendRecvInitColl()
{
    SyncGpuState state = {};
    CUDACHECK(cudaMemcpyToSymbol(GLOBAL_SYNC_STATE, &state, sizeof(SyncGpuState)));
    return testSuccess;
}
// Fill the send buffer with this rank's id and args->expected with the id of
// the rank we receive from (the previous rank in the ring), then barrier so
// no rank starts the timed loop before its peers are initialized.
// NOTE(review): `in_place` is accepted but unused; SendRecvRunTest forces
// out-of-place mode for this test.
testResult_t SendRecvInitData(struct testArgs* args, int in_place)
{
    size_t sendCount = args->sendBytes / sizeof(int);
    size_t recvCount = args->expectedBytes / sizeof(int);
    size_t maxCount = std::max(sendCount, recvCount);
    int rank = args->proc;
    // Clear the whole send buffer first, then overwrite the int-aligned
    // prefix with this rank's id.
    CUDACHECK(cudaMemset(args->sendbuff, 0, args->sendBytes));
    std::vector<int> dataHost(maxCount, rank);
    CUDACHECK(cudaMemcpy(args->sendbuff, dataHost.data(), sendCount * sizeof(int), cudaMemcpyHostToDevice));
    // We receive from the previous rank in the ring.
    int recvPeerRank = (rank - 1 + args->totalProcs) % args->totalProcs;
    for (size_t i = 0; i < recvCount; i++) {
        dataHost[i] = recvPeerRank;
    }
    CUDACHECK(cudaMemcpy(args->expected, dataHost.data(), recvCount * sizeof(int), cudaMemcpyHostToDevice));
    MSCCLPPCHECK(mscclppBootstrapBarrier(args->comm));
    return testSuccess;
}
// Bandwidth model for point-to-point send/recv.
// algBw: bytes moved by this rank (count * typesize) in GB/s over the
// measured time. Each byte crosses exactly one link, so busBw equals algBw
// (factor 1). `nranks` does not enter the formula.
void SendRecvGetBw(size_t count, int typesize, double sec, double* algBw, double* busBw, int nranks)
{
    const double gbPerSec = (double)(count * typesize) / 1.0E9 / sec;
    *algBw = gbPerSec;
    *busBw = gbPerSec * 1.0;
}
// testColl::runColl hook: size the grid from `count` via getBlockNum, split
// the data into ceil-divided per-block slices, and launch on the given
// stream. `count` is treated as bytes by the THRES_BYTES_PER_BLOCK heuristic
// — TODO confirm units against the harness caller.
// NOTE(review): sendbuff/recvbuff/nranksPerNode/kernel_num are unused; the
// kernel reads its connections from __constant__ memory instead.
testResult_t SendRecvRunColl(void* sendbuff, void* recvbuff, int nranksPerNode, size_t count, mscclppComm_t comm,
                             cudaStream_t stream, int kernel_num)
{
    int blockNum = getBlockNum(count);
    // Ceil-divide so the per-block slices cover all of `count`.
    size_t bytesPerBlock = (count + blockNum - 1) / blockNum;
    kernel<<<blockNum, BLOCK_THREADS_NUM, 0, stream>>>(comm->rank, count, bytesPerBlock);
    return testSuccess;
}
// Collective descriptor wired into the common test harness (includes the
// initColl hook that resets the device-wide barrier state).
struct testColl sendRecvTest = {"SendRecvTest", SendRecvGetCollByteCount, SendRecvInitColl, SendRecvInitData,
                                SendRecvGetBw, SendRecvRunColl};
// Harness hook: only send/recv counts are needed by the caller, so the
// remaining outputs of the byte-count helper are computed and discarded.
void SendRecvGetBuffSize(size_t* sendcount, size_t* recvcount, size_t count, int nranks)
{
    size_t ignoredParam;
    size_t ignoredSendOff;
    size_t ignoredRecvOff;
    SendRecvGetCollByteCount(sendcount, recvcount, &ignoredParam, &ignoredSendOff, &ignoredRecvOff, count, nranks);
}
// Establish the two connections this rank needs: one to the next rank in the
// ring (send direction, backed by sendbuff) and one to the previous rank
// (receive direction, backed by recvbuff). Same-node peers use P2P
// transport; cross-node peers use IB. Tags from getSendTag/getRecvTag keep
// the two directions between a pair of ranks distinct.
testResult_t SendRecvSetupConnections(struct testArgs* args)
{
    int rank = args->proc;
    int worldSize = args->totalProcs;
    int ranksPerNode = args->nranksPerNode;
    int thisNode = rank / ranksPerNode;
    int localRank = rank % ranksPerNode;
    // Assumes IB devices are named mlx5_ib<localRank> — TODO confirm the
    // device naming scheme on the target cluster.
    std::string ibDevStr = "mlx5_ib" + std::to_string(localRank);
    int sendToRank = (rank + 1) % worldSize;
    int recvFromRank = (rank - 1 + worldSize) % worldSize;
    std::array<int, 2> ranks = {sendToRank, recvFromRank};
    for (int i = 0; i < 2; i++) {
        int r = ranks[i];
        // Same-node peer -> P2P (no IB device); otherwise IB.
        const char* ibDev = r / ranksPerNode == thisNode ? nullptr : ibDevStr.c_str();
        mscclppTransport_t transportType = ibDev == nullptr ? mscclppTransportP2P : mscclppTransportIB;
        // i == 0 is the send direction, i == 1 the receive direction.
        void* buff = (i == 0) ? args->sendbuff : args->recvbuff;
        int tag = (i == 0) ? getSendTag(rank, r) : getRecvTag(rank, r);
        MSCCLPPCHECK(mscclppConnect(args->comm, r, tag, buff, args->maxbytes, transportType, ibDev));
    }
    MSCCLPPCHECK(mscclppConnectionSetup(args->comm));
    return testSuccess;
}
// Test entry point: registers the send/recv collective, forces out-of-place
// mode, looks up the per-direction device connections (tags must match those
// used in SendRecvSetupConnections), stages them in __constant__ memory for
// the kernel, then runs the timed harness.
testResult_t SendRecvRunTest(struct testArgs* args)
{
    args->collTest = &sendRecvTest;
    int rank = args->proc, worldSize = args->totalProcs;
    // only support out-of-place for sendrecv test
    args->in_place = 0;
    mscclppDevConn_t* sendDevConn;
    mscclppDevConn_t* recvDevConn;
    // Send to the next rank, receive from the previous one in the ring.
    MSCCLPPCHECK(mscclppGetDeviceConnection(args->comm, (rank + 1) % worldSize, getSendTag(rank, (rank + 1) % worldSize),
                                            &sendDevConn));
    MSCCLPPCHECK(mscclppGetDeviceConnection(args->comm, (rank - 1 + worldSize) % worldSize,
                                            getRecvTag(rank, (rank - 1 + worldSize) % worldSize), &recvDevConn));
    CUDACHECK(cudaMemcpyToSymbol(sendConnConst, sendDevConn, sizeof(mscclppDevConn_t)));
    CUDACHECK(cudaMemcpyToSymbol(recvConnConst, recvDevConn, sizeof(mscclppDevConn_t)));
    TESTCHECK(TimeTest(args));
    return testSuccess;
}
// Engine hooks picked up by the common harness, including the
// connection-setup callback; the final slot is unused by this test. The weak
// alias makes this the mscclppTestEngine the driver resolves when linked alone.
struct testEngine sendRecvTestEngine = {SendRecvGetBuffSize, SendRecvRunTest, SendRecvSetupConnections, nullptr};
#pragma weak mscclppTestEngine = sendRecvTestEngine