diff --git a/Makefile b/Makefile index b0c78aa1..a488d198 100644 --- a/Makefile +++ b/Makefile @@ -116,7 +116,7 @@ LIBSONAME := $(LIBNAME).$(MSCCLPP_MAJOR) LIBTARGET := $(BUILDDIR)/$(LIBDIR)/$(LIBNAME).$(MSCCLPP_MAJOR).$(MSCCLPP_MINOR) TESTSDIR := tests -TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc p2p_test.cu allgather_test.cu allgather_test2.cu) +TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc p2p_test.cu allgather_test.cu allgather_test2.cu allreduce_allpairs_test.cu) TESTSOBJS := $(patsubst %.cc,%.o,$(TESTSSRCS)) $(patsubst %.cu,%.o,$(TESTSSRCS)) TESTSOBJTARGETS := $(TESTSOBJS:%=$(BUILDDIR)/$(OBJDIR)/%) TESTSBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(TESTSOBJS)) diff --git a/src/ib.cc b/src/ib.cc index 51a5820c..8eef4182 100644 --- a/src/ib.cc +++ b/src/ib.cc @@ -342,7 +342,7 @@ int mscclppIbQp::rts() } int mscclppIbQp::stageSend(struct mscclppIbMr *ibMr, const mscclppIbMrInfo *info, uint32_t size, - uint64_t wrId, uint64_t offset, bool signaled) + uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled) { if (this->wrn >= MSCCLPP_IB_MAX_SENDS) { return -1; @@ -357,10 +357,10 @@ int mscclppIbQp::stageSend(struct mscclppIbMr *ibMr, const mscclppIbMrInfo *info wr_->num_sge = 1; wr_->opcode = IBV_WR_RDMA_WRITE; wr_->send_flags = signaled ? 
IBV_SEND_SIGNALED : 0; - wr_->wr.rdma.remote_addr = (uint64_t)(info->addr) + offset; + wr_->wr.rdma.remote_addr = (uint64_t)(info->addr) + dstOffset; wr_->wr.rdma.rkey = info->rkey; wr_->next = nullptr; - sge_->addr = (uint64_t)(ibMr->buff) + offset; + sge_->addr = (uint64_t)(ibMr->buff) + srcOffset; sge_->length = size; sge_->lkey = ibMr->mr->lkey; if (wrn > 0) { @@ -371,9 +371,9 @@ int mscclppIbQp::stageSend(struct mscclppIbMr *ibMr, const mscclppIbMrInfo *info } int mscclppIbQp::stageSendWithImm(struct mscclppIbMr *ibMr, const mscclppIbMrInfo *info, uint32_t size, - uint64_t wrId, uint64_t offset, bool signaled, unsigned int immData) + uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData) { - int wrn = this->stageSend(ibMr, info, size, wrId, offset, signaled); + int wrn = this->stageSend(ibMr, info, size, wrId, srcOffset, dstOffset, signaled); this->wrs[wrn - 1].opcode = IBV_WR_RDMA_WRITE_WITH_IMM; this->wrs[wrn - 1].imm_data = immData; return wrn; diff --git a/src/include/ib.h b/src/include/ib.h index a24891b8..8dc615c6 100644 --- a/src/include/ib.h +++ b/src/include/ib.h @@ -48,9 +48,9 @@ struct mscclppIbQp { int rtr(const mscclppIbQpInfo *info); int rts(); int stageSend(struct mscclppIbMr *ibMr, const mscclppIbMrInfo *info, uint32_t size, - uint64_t wrId, uint64_t offset, bool signaled); + uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled); int stageSendWithImm(struct mscclppIbMr *ibMr, const mscclppIbMrInfo *info, uint32_t size, - uint64_t wrId, uint64_t offset, bool signaled, unsigned int immData); + uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData); int postSend(); int postRecv(uint64_t wrId); int pollCq(); diff --git a/src/include/mscclpp.h b/src/include/mscclpp.h index a4035710..54a3a847 100644 --- a/src/include/mscclpp.h +++ b/src/include/mscclpp.h @@ -32,13 +32,14 @@ union alignas(16) mscclppTrigger { uint64_t value[2]; struct { // first 64 bits: value[0] 
- uint64_t dataSize : MSCCLPP_BITS_SIZE; - uint64_t dataOffset : MSCCLPP_BITS_OFFSET; - uint64_t : (64-MSCCLPP_BITS_SIZE-MSCCLPP_BITS_OFFSET); // ensure 64-bit alignment + uint64_t dataSize : MSCCLPP_BITS_SIZE; + uint64_t srcDataOffset : MSCCLPP_BITS_OFFSET; + uint64_t : (64-MSCCLPP_BITS_SIZE-MSCCLPP_BITS_OFFSET); // ensure 64-bit alignment // second 64 bits: value[1] - uint64_t connId : MSCCLPP_BITS_CONNID; - uint64_t type : MSCCLPP_BITS_TYPE; - uint64_t : (64-MSCCLPP_BITS_CONNID-MSCCLPP_BITS_TYPE); // ensure 64-bit alignment + uint64_t dstDataOffset : MSCCLPP_BITS_OFFSET; + uint64_t connId : MSCCLPP_BITS_CONNID; + uint64_t type : MSCCLPP_BITS_TYPE; + uint64_t : (64-MSCCLPP_BITS_OFFSET-MSCCLPP_BITS_CONNID-MSCCLPP_BITS_TYPE); // ensure 64-bit alignment } fields; }; @@ -54,12 +55,16 @@ struct mscclppConcurrentFifo { return curFifoHead; } - __forceinline__ __device__ void setTrigger(mscclppTrigger_t trig, uint64_t type, uint64_t dataOffset, uint64_t dataSize) { + __forceinline__ __device__ void setTrigger(mscclppTrigger_t trig, uint64_t type, uint64_t srcDataOffset, uint64_t dstDataOffset, uint64_t dataSize) { asm volatile( "st.volatile.global.v2.u64 [%0], {%1,%2};" ::"l"(&trig->value), - "l"((dataOffset << (MSCCLPP_BITS_SIZE)) + - (dataSize)), - "l"((type << MSCCLPP_BITS_CONNID) + this->connId)); + "l"((srcDataOffset << MSCCLPP_BITS_SIZE) + dataSize), + "l"((((type << MSCCLPP_BITS_CONNID) + this->connId) << MSCCLPP_BITS_OFFSET) + dstDataOffset) + ); + } + + __forceinline__ __device__ void setTrigger(mscclppTrigger_t trig, uint64_t type, uint64_t dataOffset, uint64_t dataSize) { + setTrigger(trig, type, dataOffset, dataOffset, dataSize); } __forceinline__ __device__ void waitTrigger(mscclppRequest_t req) { @@ -195,6 +200,10 @@ mscclppResult_t mscclppProxyLaunch(mscclppComm_t comm); mscclppResult_t mscclppProxyStop(mscclppComm_t comm); +mscclppResult_t mscclppCommRank(mscclppComm_t comm, int* rank); + +mscclppResult_t mscclppCommSize(mscclppComm_t comm, int* size); + 
#ifdef __cplusplus } // end extern "C" #endif diff --git a/src/init.cc b/src/init.cc index 8aa975ef..d2d88c83 100644 --- a/src/init.cc +++ b/src/init.cc @@ -402,3 +402,25 @@ mscclppResult_t mscclppProxyStop(mscclppComm_t comm) MSCCLPPCHECK(mscclppProxyDestroy(comm)); return mscclppSuccess; } + +MSCCLPP_API(mscclppResult_t, mscclppCommRank, mscclppComm_t comm, int* rank); +mscclppResult_t mscclppCommRank(mscclppComm_t comm, int* rank) +{ + if (comm == NULL || rank == NULL) { + WARN("comm or rank cannot be null"); + return mscclppInvalidUsage; + } + *rank = comm->rank; + return mscclppSuccess; +} + +MSCCLPP_API(mscclppResult_t, mscclppCommSize, mscclppComm_t comm, int* size); +mscclppResult_t mscclppCommSize(mscclppComm_t comm, int* size) +{ + if (comm == NULL || size == NULL) { + WARN("comm or size cannot be null"); + return mscclppInvalidUsage; + } + *size = comm->nRanks; + return mscclppSuccess; +} \ No newline at end of file diff --git a/src/proxy.cc b/src/proxy.cc index c11c053b..a6bd39b3 100644 --- a/src/proxy.cc +++ b/src/proxy.cc @@ -81,8 +81,8 @@ void* mscclppProxyServiceP2P(void* _args) { // Iterate over what send is needed if (trigger.fields.type & mscclppData){ - void *srcBuff = (void *)((char *)conn->devConn->localBuff + trigger.fields.dataOffset); - void *dstBuff = (void *)((char *)conn->devConn->remoteBuff + trigger.fields.dataOffset); + void *srcBuff = (void *)((char *)conn->devConn->localBuff + trigger.fields.srcDataOffset); + void *dstBuff = (void *)((char *)conn->devConn->remoteBuff + trigger.fields.dstDataOffset); PROXYCUDACHECK(cudaMemcpyAsync(dstBuff, srcBuff, trigger.fields.dataSize, cudaMemcpyDeviceToDevice, stream)); } if (trigger.fields.type & mscclppFlag) { @@ -222,12 +222,13 @@ void* mscclppProxyServiceIb(void* _args) { if (trigger.fields.type & mscclppData) { conn->ibQp->stageSend(conn->ibBuffMr, &conn->ibBuffMrInfo, (uint32_t)trigger.fields.dataSize, - /*wrId=*/0, /*offset=*/trigger.fields.dataOffset, /*signaled=*/false); + /*wrId=*/0, 
/*srcOffset=*/trigger.fields.srcDataOffset, /*dstOffset=*/trigger.fields.dstDataOffset, + /*signaled=*/false); } if (trigger.fields.type & mscclppFlag) { // My local flag is copied to the peer's proxy flag conn->ibQp->stageSend(conn->ibLocalFlagMr, &conn->ibProxyFlagMrInfo, sizeof(uint64_t), - /*wrId=*/0, /*offset=*/0, /*signaled=*/true); + /*wrId=*/0, /*srcOffset=*/0, /*dstOffset=*/0, /*signaled=*/true); } int ret; if ((ret = conn->ibQp->postSend()) != 0) { diff --git a/tests/allgather_test2.cu b/tests/allgather_test2.cu index 290b4e9c..03df81f8 100644 --- a/tests/allgather_test2.cu +++ b/tests/allgather_test2.cu @@ -89,9 +89,9 @@ __global__ void kernel(int rank, int world_size, int nelemsPerGPU) // Trigger sending data, flag and synchronize after int ibPortion = nelemsPerGPU/12;//nelemsPerGPU/12; if (isIB) - devConn.fifo.setTrigger(trig, mscclppFlag | mscclppData | mscclppSync, rank * nelemsPerGPU * sizeof(int) + (nelemsPerGPU - ibPortion)*sizeof(int), ibPortion*sizeof(int)); + devConn.fifo.setTrigger(trig, mscclppFlag | mscclppData | mscclppSync, rank * nelemsPerGPU * sizeof(int) + (nelemsPerGPU - ibPortion)*sizeof(int), rank * nelemsPerGPU * sizeof(int) + (nelemsPerGPU - ibPortion)*sizeof(int), ibPortion*sizeof(int)); else - devConn.fifo.setTrigger(trig, mscclppFlag | mscclppData | mscclppSync, rank * nelemsPerGPU * sizeof(int), (nelemsPerGPU-ibPortion)*sizeof(int)); + devConn.fifo.setTrigger(trig, mscclppFlag | mscclppData | mscclppSync, rank * nelemsPerGPU * sizeof(int), rank * nelemsPerGPU * sizeof(int), (nelemsPerGPU-ibPortion)*sizeof(int)); // Wait on the request to make sure it is safe to reuse buffer and flag devConn.fifo.waitTrigger(req); } diff --git a/tests/allreduce_allpairs_test.cu b/tests/allreduce_allpairs_test.cu new file mode 100644 index 00000000..2ad4159a --- /dev/null +++ b/tests/allreduce_allpairs_test.cu @@ -0,0 +1,297 @@ +#include "mscclpp.h" +#include +#include +#include + +#include "common.h" + +#define MSCCLPPCHECK(call) do { \ + 
mscclppResult_t res = call; \ + if (res != mscclppSuccess && res != mscclppInProgress) { \ + /* Print the back trace*/ \ + printf("Failure at %s:%d -> %d\n", __FILE__, __LINE__, res); \ + return res; \ + } \ +} while (0) + +#define CUDACHECK(cmd) do { \ + cudaError_t err = cmd; \ + if( err != cudaSuccess ) { \ + printf("%s:%d Cuda failure '%s'\n", __FILE__, __LINE__, cudaGetErrorString(err)); \ + exit(EXIT_FAILURE); \ + } \ +} while(false) + +struct Volume { + size_t offset; + size_t size; +}; + +__host__ __device__ Volume chunkVolume(size_t totalSize, size_t totalChunks, size_t chunkIdx, size_t chunkCount) { + size_t remainder = totalSize % totalChunks; + size_t smallChunk = totalSize / totalChunks; + size_t largeChunk = smallChunk + 1; + size_t numLargeChunks = chunkIdx < remainder ? (remainder - chunkIdx < chunkCount ? remainder - chunkIdx : chunkCount) : 0; + size_t numSmallChunks = chunkCount - numLargeChunks; + size_t offset = (chunkIdx < remainder ? chunkIdx : remainder) * largeChunk + + (chunkIdx > remainder ? chunkIdx - remainder : 0) * smallChunk; + return Volume{offset, numLargeChunks * largeChunk + numSmallChunks * smallChunk}; +} + +template <typename T> +struct AllreduceAllpairs { + int rank; + int nRanks; + T* userData; + size_t userSize; + T* scratch; + size_t scratchSize; + mscclppDevConn_t* conns; + uint64_t* connFlags; + cuda::barrier* barrier; + + __device__ void run(int idx) { + int myPeer = peerRank(idx, rank); + mscclppDevConn_t phase1SendConn = conns[phase1SendConnIdx(myPeer)]; + mscclppDevConn_t phase1RecvConn = conns[phase1RecvConnIdx(myPeer)]; + mscclppDevConn_t phase2Conn = conns[phase2ConnIdx(myPeer)]; + + // 1st communication phase: send data to the scratch buffer of the peer associated with this block + Volume toPeer = chunkVolume(userSize, nRanks, myPeer, 1); + // Now we need to figure out the offset of this chunk in the scratch buffer of the destination. 
+ // The destination will have allocated a scratch buffer of size numPeers() * toPeer.size and + // inside that each of the destination's peers send to the nth chunk, where n is the index of the + // source peer from the destination's perspective. + size_t dstOffset = peerIdx(rank, myPeer) * toPeer.size; + send(phase1SendConn, toPeer.offset, dstOffset, toPeer.size); + recv(phase1RecvConn); + + if (threadIdx.x == 0) + barrier->arrive_and_wait(); + __syncthreads(); + + // Local reduction: every block reduces a slice of each chunk in the scratch buffer into the user buffer + Volume rankUserChunk = chunkVolume(userSize, nRanks, rank, 1); + T* userChunk = userData + rankUserChunk.offset; + Volume blockUserChunk = chunkVolume(rankUserChunk.size, numBlocks(), idx, 1); + for (int peerIdx = 0; peerIdx < numPeers(); ++peerIdx) { + assert(scratchSize % numPeers() == 0); + assert(scratchSize / numPeers() == rankUserChunk.size); + size_t scratchChunkSize = scratchSize / numPeers(); + T* scratchChunk = scratch + peerIdx * scratchChunkSize; + Volume blockScratchChunk = chunkVolume(scratchChunkSize, numBlocks(), idx, 1); + assert(blockScratchChunk.size == blockUserChunk.size); + reduce(userChunk + blockUserChunk.offset, scratchChunk + blockScratchChunk.offset, blockScratchChunk.size); + } + + if (threadIdx.x == 0) + barrier->arrive_and_wait(); + __syncthreads(); + + // 2nd communication phase: send the now reduced data between the user buffers + Volume srcVolume2 = chunkVolume(userSize, nRanks, rank, 1); + send(phase2Conn, srcVolume2.offset, srcVolume2.offset, srcVolume2.size); + recv(phase2Conn); + + } + + __device__ void send(mscclppDevConn_t& conn, size_t srcOffset, size_t dstOffset, size_t size) { + if (threadIdx.x == 0) { + volatile uint64_t *localFlag = conn.localFlag; + *localFlag = 1; // 1 is used to signal the send + + mscclppTrigger_t trigger; + auto request = conn.fifo.getTrigger(&trigger); + conn.fifo.setTrigger(trigger, mscclppData | mscclppFlag, srcOffset * 
sizeof(T), dstOffset * sizeof(T), size * sizeof(T)); + } + __syncthreads(); + } + + __device__ void recv(mscclppDevConn_t& conn) { + if (threadIdx.x == 0) { + volatile uint64_t *proxyFlag = conn.proxyFlag; + while (*proxyFlag != 1) {} + *proxyFlag = 0; + } + __syncthreads(); + } + + __host__ __device__ int numPeers() { + return nRanks - 1; + } + + __host__ __device__ int numBlocks() { + return numPeers(); + } + + __host__ __device__ int peerIdx(int peerRank, int myRank) { + return peerRank < myRank ? peerRank : peerRank - 1; + } + + __host__ __device__ int peerRank(int peerIdx, int myRank) { + return peerIdx < myRank ? peerIdx : peerIdx + 1; + } + + __host__ __device__ int phase1SendConnIdx(int peerRank) { + return peerIdx(peerRank, rank) * 3; + } + + __host__ __device__ int phase1RecvConnIdx(int peerRank) { + return peerIdx(peerRank, rank) * 3 + 1; + } + + __host__ __device__ int phase2ConnIdx(int peerRank) { + return peerIdx(peerRank, rank) * 3 + 2; + } + + void freeGPUResources() { + if (scratch) + CUDACHECK(cudaFree(scratch)); + scratch = nullptr; + if (connFlags) + CUDACHECK(cudaFree(connFlags)); + connFlags = nullptr; + if (conns) + CUDACHECK(cudaFree(conns)); + conns = nullptr; + if (barrier) + CUDACHECK(cudaFree(barrier)); + barrier = nullptr; + } +}; + +// The builder class encapsulates the +template +class AllreduceAllpairsBuilder { + AllreduceAllpairs d; + std::vector hostConns; + +public: + + // The constructor is called after the user has allocated the buffer to be allreduced + AllreduceAllpairsBuilder(T* data, size_t size) { + d.userData = data; + d.userSize = size; + d.scratch = nullptr; + d.connFlags = nullptr; + d.conns = nullptr; + d.barrier = nullptr; + } + + // connect is called after rank initialization but before connection setup + mscclppResult_t connect(mscclppComm_t comm) { + MSCCLPPCHECK(mscclppCommRank(comm, &d.rank)); + MSCCLPPCHECK(mscclppCommSize(comm, &d.nRanks)); + + Volume myChunks = chunkVolume(d.userSize, d.nRanks, d.rank, 1); + 
d.scratchSize = myChunks.size * d.numPeers(); + + CUDACHECK(cudaMalloc(&d.scratch, d.scratchSize * sizeof(T))); + CUDACHECK(cudaMalloc(&d.connFlags, 3 * sizeof(uint64_t))); + CUDACHECK(cudaMemset(d.connFlags, 0, 3 * sizeof(uint64_t))); + + hostConns.resize(d.numPeers() * 3); + for (int peer = 0; peer < d.nRanks; ++peer) { + if (peer != d.rank) { + int sendTag = d.rank < peer ? 0 : 1; + int recvTag = d.rank < peer ? 1 : 0; + MSCCLPPCHECK(mscclppConnect(comm, hostConns.data() + d.phase1SendConnIdx(peer), peer, d.userData, d.userSize * sizeof(T), d.connFlags + 0, sendTag, mscclppTransportP2P, nullptr)); + MSCCLPPCHECK(mscclppConnect(comm, hostConns.data() + d.phase1RecvConnIdx(peer), peer, d.scratch, d.scratchSize * sizeof(T), d.connFlags + 1, recvTag, mscclppTransportP2P, nullptr)); + MSCCLPPCHECK(mscclppConnect(comm, hostConns.data() + d.phase2ConnIdx(peer), peer, d.userData, d.userSize * sizeof(T), d.connFlags + 2, 2, mscclppTransportP2P, nullptr)); + } + } + + return mscclppSuccess; + } + + // finishSetup is called after connection setup and returns an algorithm object that is ready to be passed to a GPU kernel + AllreduceAllpairs finishSetup() { + CUDACHECK(cudaMalloc(&d.conns, hostConns.size() * sizeof(mscclppDevConn_t))); + CUDACHECK(cudaMemcpy(d.conns, hostConns.data(), hostConns.size() * sizeof(mscclppDevConn_t), cudaMemcpyHostToDevice)); + CUDACHECK(cudaMalloc(&d.barrier, sizeof(cuda::barrier))); + cuda::barrier initBarrier(d.numBlocks()); + CUDACHECK(cudaMemcpy(d.barrier, &initBarrier, sizeof(cuda::barrier), cudaMemcpyHostToDevice)); + return d; + } +}; + +template +__device__ void reduceSum(T* dst, T* src, size_t size) { + for (int i = threadIdx.x; i < size; i += blockDim.x) { + dst[i] += src[i]; + } +} + +template +__global__ void init(T* data, size_t size, int rank) { + for (int i = threadIdx.x; i < size; i += blockDim.x) { + data[i] = rank; + } +} + +// The main test kernel +template +__global__ void testKernel(AllreduceAllpairs d) { + 
d.run(blockIdx.x); +} + +int main(int argc, const char *argv[]) { +#ifdef MSCCLPP_USE_MPI_FOR_TESTS + MPI_Init(NULL, NULL); +#endif + const char* ip_port; + int rank, world_size; + parse_arguments(argc, argv, &ip_port, &rank, &world_size); + + CUDACHECK(cudaSetDevice(rank)); + + // Allocate and initialize 1 MB of data + int *data; + size_t dataSize = 1024 * 1024 / sizeof(int); + CUDACHECK(cudaMalloc(&data, dataSize * sizeof(int))); + init<<<1, 256>>>(data, dataSize, rank); + + // Create the collective + AllreduceAllpairsBuilder builder(data, dataSize); + + // Create the communicator + mscclppComm_t comm; + MSCCLPPCHECK(mscclppCommInitRank(&comm, world_size, rank, ip_port)); + + // Connect the collective + builder.connect(comm); + + // Finish the setup + MSCCLPPCHECK(mscclppConnectionSetup(comm)); + MSCCLPPCHECK(mscclppProxyLaunch(comm)); + auto allreduce = builder.finishSetup(); + + // Run the collective + testKernel<<>>(allreduce); + + // Wait for kernel to finish + CUDACHECK(cudaDeviceSynchronize()); + + // Check the result + int* hostData = new int[dataSize]; + CUDACHECK(cudaMemcpy(hostData, data, dataSize * sizeof(int), cudaMemcpyDeviceToHost)); + int expectedValue = world_size * (world_size - 1) / 2; + for (size_t i = 0; i < dataSize; ++i) { + if (hostData[i] != expectedValue) { + printf("Error at index %lu: %d != %d\n", i, hostData[i], expectedValue); + return 1; + } + } + + MSCCLPPCHECK(mscclppProxyStop(comm)); + + MSCCLPPCHECK(mscclppCommDestroy(comm)); + +#ifdef MSCCLPP_USE_MPI_FOR_TESTS + if (argc == 2) { + MPI_Finalize(); + } +#endif + printf("Succeeded! 
%d\n", rank); + return 0; +} \ No newline at end of file diff --git a/tests/common.h b/tests/common.h new file mode 100644 index 00000000..e3d6b8c9 --- /dev/null +++ b/tests/common.h @@ -0,0 +1,45 @@ +#ifndef MSCCLPP_TESTS_COMMON_H_ +#define MSCCLPP_TESTS_COMMON_H_ + +#include <stdio.h> +#include <stdlib.h> + +#ifdef MSCCLPP_USE_MPI_FOR_TESTS +#include "mpi.h" +#endif // MSCCLPP_USE_MPI_FOR_TESTS + +void print_usage(const char *prog) +{ +#ifdef MSCCLPP_USE_MPI_FOR_TESTS + printf("usage: %s IP:PORT [rank nranks]\n", prog); +#else + printf("usage: %s IP:PORT rank nranks\n", prog); +#endif +} + +void parse_arguments(int argc, const char *argv[], const char** ip_port, int* rank, int* world_size) { +#ifdef MSCCLPP_USE_MPI_FOR_TESTS + if (argc != 2 && argc != 4) { + print_usage(argv[0]); + exit(-1); + } + *ip_port = argv[1]; + if (argc == 4) { + *rank = atoi(argv[2]); + *world_size = atoi(argv[3]); + } else { + MPI_Comm_rank(MPI_COMM_WORLD, rank); + MPI_Comm_size(MPI_COMM_WORLD, world_size); + } +#else + if (argc != 4) { + print_usage(argv[0]); + exit(-1); + } + *ip_port = argv[1]; + *rank = atoi(argv[2]); + *world_size = atoi(argv[3]); +#endif +} + +#endif // MSCCLPP_TESTS_COMMON_H_ \ No newline at end of file diff --git a/tests/p2p_test.cu b/tests/p2p_test.cu index a780b21b..eb65ca05 100644 --- a/tests/p2p_test.cu +++ b/tests/p2p_test.cu @@ -1,12 +1,11 @@ #include "mscclpp.h" -#ifdef MSCCLPP_USE_MPI_FOR_TESTS -#include "mpi.h" -#endif // MSCCLPP_USE_MPI_FOR_TESTS #include #include #include #include +#include "common.h" + #define RANKS_PER_NODE 8 #define USE_DMA_FOR_P2P 1 #define TEST_CONN_TYPE 0 // 0: P2P(for local)+IB(for remote), 1: IB-Only @@ -147,42 +146,14 @@ int cudaNumToIbNum(int cudaNum) return ibNum; } -void print_usage(const char *prog) -{ -#ifdef MSCCLPP_USE_MPI_FOR_TESTS - printf("usage: %s IP:PORT [rank nranks]\n", prog); -#else - printf("usage: %s IP:PORT rank nranks\n", prog); -#endif -} - int main(int argc, const char *argv[]) { #ifdef MSCCLPP_USE_MPI_FOR_TESTS - 
if (argc != 2 && argc != 4) { - print_usage(argv[0]); - return -1; - } - const char *ip_port = argv[1]; - int rank; - int world_size; - if (argc == 4) { - rank = atoi(argv[2]); - world_size = atoi(argv[3]); - } else { - MPI_Init(NULL, NULL); - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - MPI_Comm_size(MPI_COMM_WORLD, &world_size); - } -#else - if (argc != 4) { - print_usage(argv[0]); - return -1; - } - const char *ip_port = argv[1]; - int rank = atoi(argv[2]); - int world_size = atoi(argv[3]); + MPI_Init(NULL, NULL); #endif + const char* ip_port; + int rank, world_size; + parse_arguments(argc, argv, &ip_port, &rank, &world_size); int localRank = rankToLocalRank(rank); int thisNode = rankToNode(rank); int cudaNum = localRank;