allgather_test code cleanup

This commit is contained in:
Madan Musuvathi
2023-03-22 20:38:29 +00:00
parent 261fd7f838
commit 4c459aa0df
4 changed files with 42 additions and 41 deletions

View File

@@ -8,15 +8,6 @@
#endif
#include <stdint.h>
// Propagate errors up: evaluate a MSCCLPP call and, on any result other than
// mscclppSuccess or mscclppInProgress, log the failing file/line and return
// the error code from the enclosing function (caller must return mscclppResult_t).
// Wrapped in do { } while (0) so it behaves as a single statement, safe in
// unbraced if/else bodies.
#define MSCCLPPCHECK(call) do { \
mscclppResult_t res = call; \
if (res != mscclppSuccess && res != mscclppInProgress) { \
/* Print the back trace*/ \
printf("Failure at %s:%d -> %d\n", __FILE__, __LINE__, res); \
return res; \
} \
} while (0)
#define MSCCLPP_MAJOR 0
#define MSCCLPP_MINOR 1
#define MSCCLPP_PROXY_FIFO_SIZE 8

View File

@@ -170,7 +170,7 @@ mscclppResult_t mscclppCommDestroy(mscclppComm_t comm){
MSCCLPP_API(mscclppResult_t, mscclppConnect, mscclppComm_t comm, mscclppDevConn* devConnOut, int remoteRank,
void* localBuff, size_t buffSize, int tag, mscclppTransport_t transportType, const char *ibDev);
mscclppResult_t mscclppConnect(mscclppComm_t comm, mscclppDevConn* devConnOut, int remoteRank, void* localBuff, size_t buffSize,
int tag, mscclppTransport_t transportType, const char *ibDev/*=NULL*/)
int tag, mscclppTransport_t transportType, const char *ibDev)
{
if (comm->nConns == MAXCONNECTIONS) {
WARN("Too many connections made");

View File

@@ -1,4 +1,5 @@
#include "mscclpp.h"
#ifdef MSCCLPP_USE_MPI_FOR_TESTS
#include "mpi.h"
#endif // MSCCLPP_USE_MPI_FOR_TESTS
@@ -9,6 +10,18 @@
#define RANKS_PER_NODE 8
// Propagate errors up: on any MSCCLPP result other than mscclppSuccess or
// mscclppInProgress, log the failing file/line and return the error from the
// enclosing function (caller must return mscclppResult_t). The do/while(0)
// wrapper makes the macro act as a single statement in unbraced if/else bodies.
#define MSCCLPPCHECK(call) do { \
mscclppResult_t res = call; \
if (res != mscclppSuccess && res != mscclppInProgress) { \
/* Print the back trace*/ \
printf("Failure at %s:%d -> %d\n", __FILE__, __LINE__, res); \
return res; \
} \
} while (0)
// Check CUDA RT calls
#define CUDACHECK(cmd) do { \
cudaError_t err = cmd; \
@@ -29,6 +42,8 @@ static double getTime(void)
return (tspec.tv_nsec / 1.0e9) + tspec.tv_sec;
}
// Global communicator handle shared by the test helpers.
mscclppComm_t comm;
// Host-side per-peer device connections; fixed capacity of 16 ranks —
// assumes world_size <= 16 (no bounds check visible here, TODO confirm).
mscclppDevConn_t devConns[16];
// Device-constant mirror of devConns, populated via cudaMemcpyToSymbol
// so kernels can read the connections without a pointer argument.
__constant__ mscclppDevConn_t constDevConns[16];
__global__ void kernel(int rank, int world_size, int nelemsPerGPU)
@@ -103,6 +118,31 @@ void initializeAndAllocateAllGatherData(int rank, int world_size, size_t data_si
CUDACHECK(cudaMemcpy(*data_d, *data_h, data_size, cudaMemcpyHostToDevice));
}
mscclppResult_t setupMscclppConnections(int rank, int world_size, mscclppComm_t comm, int* data_d, size_t data_size){
  // Connect this rank to every peer: P2P for peers on the same node, IB
  // otherwise, then finalize the setup and publish the connection table to
  // device constant memory. Assumes world_size <= 16 (devConns capacity) —
  // TODO confirm against the launcher.
  const int myNode = rankToNode(rank);
  const int localGpu = rankToLocalRank(rank);
  // IB device name follows the local GPU index, e.g. "mlx5_ib0".
  std::string ibDevStr = "mlx5_ib" + std::to_string(localGpu);

  for (int peer = 0; peer < world_size; ++peer) {
    if (peer == rank) {
      continue;  // no self-connection
    }
    const bool sameNode = (rankToNode(peer) == myNode);
    const mscclppTransport_t transportType = sameNode ? mscclppTransportP2P : mscclppTransportIB;
    const char* ibDev = sameNode ? NULL : ibDevStr.c_str();
    MSCCLPPCHECK(mscclppConnect(comm, &devConns[peer], peer, data_d, data_size, 0, transportType, ibDev));
  }

  MSCCLPPCHECK(mscclppConnectionSetup(comm));
  // Mirror the host-side connection array into __constant__ memory for kernels.
  CUDACHECK(cudaMemcpyToSymbol(constDevConns, devConns, sizeof(mscclppDevConn_t) * world_size));
  return mscclppSuccess;
}
int main(int argc, const char *argv[])
{
#ifdef MSCCLPP_USE_MPI_FOR_TESTS
@@ -133,11 +173,8 @@ int main(int argc, const char *argv[])
int thisNode = rankToNode(rank);
int cudaNum = rankToLocalRank(rank);
CUDACHECK(cudaSetDevice(cudaNum));
std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum);
mscclppComm_t comm;
MSCCLPPCHECK(mscclppCommInitRank(&comm, world_size, rank, ip_port));
int *data_d;
@@ -147,31 +184,13 @@ int main(int argc, const char *argv[])
initializeAndAllocateAllGatherData(rank, world_size, data_size, nelemsPerGPU, &data_h, &data_d);
mscclppDevConn_t devConns[16];
for (int r = 0; r < world_size; ++r) {
if (r == rank) continue;
mscclppTransport_t transportType;
const char* ibDev = ibDevStr.c_str();
if (rankToNode(r) == thisNode){
ibDev = NULL;
transportType = mscclppTransportP2P;
} else {
transportType = mscclppTransportIB;
}
// Connect with all other ranks
MSCCLPPCHECK(mscclppConnect(comm, &devConns[r], r, data_d, data_size, 0, transportType, ibDev));
}
MSCCLPPCHECK(mscclppConnectionSetup(comm));
MSCCLPPCHECK(setupMscclppConnections(rank, world_size, comm, data_d, data_size));
MSCCLPPCHECK(mscclppProxyLaunch(comm));
CUDACHECK(cudaMemcpyToSymbol(constDevConns, devConns, sizeof(mscclppDevConn_t) * world_size));
cudaStream_t stream;
CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
CUDACHECK(cudaDeviceSynchronize());
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nelemsPerGPU);
CUDACHECK(cudaDeviceSynchronize());

View File

@@ -9,15 +9,6 @@
#define RANKS_PER_NODE 8
// Propagate errors up: on any MSCCLPP result other than mscclppSuccess or
// mscclppInProgress, log the failing file/line and return the error from the
// enclosing function (caller must return mscclppResult_t).
// Fix: removed the stray semicolon after `while (0)`. With the semicolon the
// do/while(0) idiom is defeated — `if (x) MSCCLPPCHECK(y); else ...` fails to
// compile because the macro's own `;` terminates the if-statement. The caller
// supplies the trailing semicolon, matching the other MSCCLPPCHECK definitions
// in this commit.
#define MSCCLPPCHECK(call) do { \
  mscclppResult_t res = call; \
  if (res != mscclppSuccess && res != mscclppInProgress) { \
    /* Print the back trace*/ \
    printf("Failure at %s:%d -> %d\n", __FILE__, __LINE__, res); \
    return res; \
  } \
} while (0)
// Check CUDA RT calls
#define CUDACHECK(cmd) do { \
cudaError_t err = cmd; \