mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-13 09:46:00 +00:00
Add allpairs allreduce test
To support this, include separate source and destination offsets in the trigger. Add functions for getting the rank and world size from a communicator.
This commit is contained in:
2
Makefile
2
Makefile
@@ -116,7 +116,7 @@ LIBSONAME := $(LIBNAME).$(MSCCLPP_MAJOR)
|
||||
LIBTARGET := $(BUILDDIR)/$(LIBDIR)/$(LIBNAME).$(MSCCLPP_MAJOR).$(MSCCLPP_MINOR)
|
||||
|
||||
TESTSDIR := tests
|
||||
TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc p2p_test.cu allgather_test.cu allgather_test2.cu)
|
||||
TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc p2p_test.cu allgather_test.cu allgather_test2.cu allreduce_allpairs_test.cu)
|
||||
TESTSOBJS := $(patsubst %.cc,%.o,$(TESTSSRCS)) $(patsubst %.cu,%.o,$(TESTSSRCS))
|
||||
TESTSOBJTARGETS := $(TESTSOBJS:%=$(BUILDDIR)/$(OBJDIR)/%)
|
||||
TESTSBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(TESTSOBJS))
|
||||
|
||||
10
src/ib.cc
10
src/ib.cc
@@ -342,7 +342,7 @@ int mscclppIbQp::rts()
|
||||
}
|
||||
|
||||
int mscclppIbQp::stageSend(struct mscclppIbMr *ibMr, const mscclppIbMrInfo *info, uint32_t size,
|
||||
uint64_t wrId, uint64_t offset, bool signaled)
|
||||
uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled)
|
||||
{
|
||||
if (this->wrn >= MSCCLPP_IB_MAX_SENDS) {
|
||||
return -1;
|
||||
@@ -357,10 +357,10 @@ int mscclppIbQp::stageSend(struct mscclppIbMr *ibMr, const mscclppIbMrInfo *info
|
||||
wr_->num_sge = 1;
|
||||
wr_->opcode = IBV_WR_RDMA_WRITE;
|
||||
wr_->send_flags = signaled ? IBV_SEND_SIGNALED : 0;
|
||||
wr_->wr.rdma.remote_addr = (uint64_t)(info->addr) + offset;
|
||||
wr_->wr.rdma.remote_addr = (uint64_t)(info->addr) + dstOffset;
|
||||
wr_->wr.rdma.rkey = info->rkey;
|
||||
wr_->next = nullptr;
|
||||
sge_->addr = (uint64_t)(ibMr->buff) + offset;
|
||||
sge_->addr = (uint64_t)(ibMr->buff) + srcOffset;
|
||||
sge_->length = size;
|
||||
sge_->lkey = ibMr->mr->lkey;
|
||||
if (wrn > 0) {
|
||||
@@ -371,9 +371,9 @@ int mscclppIbQp::stageSend(struct mscclppIbMr *ibMr, const mscclppIbMrInfo *info
|
||||
}
|
||||
|
||||
int mscclppIbQp::stageSendWithImm(struct mscclppIbMr *ibMr, const mscclppIbMrInfo *info, uint32_t size,
|
||||
uint64_t wrId, uint64_t offset, bool signaled, unsigned int immData)
|
||||
uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData)
|
||||
{
|
||||
int wrn = this->stageSend(ibMr, info, size, wrId, offset, signaled);
|
||||
int wrn = this->stageSend(ibMr, info, size, wrId, srcOffset, dstOffset, signaled);
|
||||
this->wrs[wrn - 1].opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
|
||||
this->wrs[wrn - 1].imm_data = immData;
|
||||
return wrn;
|
||||
|
||||
@@ -48,9 +48,9 @@ struct mscclppIbQp {
|
||||
int rtr(const mscclppIbQpInfo *info);
|
||||
int rts();
|
||||
int stageSend(struct mscclppIbMr *ibMr, const mscclppIbMrInfo *info, uint32_t size,
|
||||
uint64_t wrId, uint64_t offset, bool signaled);
|
||||
uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled);
|
||||
int stageSendWithImm(struct mscclppIbMr *ibMr, const mscclppIbMrInfo *info, uint32_t size,
|
||||
uint64_t wrId, uint64_t offset, bool signaled, unsigned int immData);
|
||||
uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData);
|
||||
int postSend();
|
||||
int postRecv(uint64_t wrId);
|
||||
int pollCq();
|
||||
|
||||
@@ -32,13 +32,14 @@ union alignas(16) mscclppTrigger {
|
||||
uint64_t value[2];
|
||||
struct {
|
||||
// first 64 bits: value[0]
|
||||
uint64_t dataSize : MSCCLPP_BITS_SIZE;
|
||||
uint64_t dataOffset : MSCCLPP_BITS_OFFSET;
|
||||
uint64_t : (64-MSCCLPP_BITS_SIZE-MSCCLPP_BITS_OFFSET); // ensure 64-bit alignment
|
||||
uint64_t dataSize : MSCCLPP_BITS_SIZE;
|
||||
uint64_t srcDataOffset : MSCCLPP_BITS_OFFSET;
|
||||
uint64_t : (64-MSCCLPP_BITS_SIZE-MSCCLPP_BITS_OFFSET); // ensure 64-bit alignment
|
||||
// second 64 bits: value[1]
|
||||
uint64_t connId : MSCCLPP_BITS_CONNID;
|
||||
uint64_t type : MSCCLPP_BITS_TYPE;
|
||||
uint64_t : (64-MSCCLPP_BITS_CONNID-MSCCLPP_BITS_TYPE); // ensure 64-bit alignment
|
||||
uint64_t dstDataOffset : MSCCLPP_BITS_OFFSET;
|
||||
uint64_t connId : MSCCLPP_BITS_CONNID;
|
||||
uint64_t type : MSCCLPP_BITS_TYPE;
|
||||
uint64_t : (64-MSCCLPP_BITS_OFFSET-MSCCLPP_BITS_CONNID-MSCCLPP_BITS_TYPE); // ensure 64-bit alignment
|
||||
} fields;
|
||||
};
|
||||
|
||||
@@ -54,12 +55,16 @@ struct mscclppConcurrentFifo {
|
||||
return curFifoHead;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void setTrigger(mscclppTrigger_t trig, uint64_t type, uint64_t dataOffset, uint64_t dataSize) {
|
||||
__forceinline__ __device__ void setTrigger(mscclppTrigger_t trig, uint64_t type, uint64_t srcDataOffset, uint64_t dstDataOffset, uint64_t dataSize) {
|
||||
asm volatile(
|
||||
"st.volatile.global.v2.u64 [%0], {%1,%2};" ::"l"(&trig->value),
|
||||
"l"((dataOffset << (MSCCLPP_BITS_SIZE)) +
|
||||
(dataSize)),
|
||||
"l"((type << MSCCLPP_BITS_CONNID) + this->connId));
|
||||
"l"((srcDataOffset << MSCCLPP_BITS_SIZE) + dataSize),
|
||||
"l"((((type << MSCCLPP_BITS_CONNID) + this->connId) << MSCCLPP_BITS_OFFSET) + dstDataOffset)
|
||||
);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void setTrigger(mscclppTrigger_t trig, uint64_t type, uint64_t dataOffset, uint64_t dataSize) {
|
||||
setTrigger(trig, type, dataOffset, dataOffset, dataSize);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void waitTrigger(mscclppRequest_t req) {
|
||||
@@ -195,6 +200,10 @@ mscclppResult_t mscclppProxyLaunch(mscclppComm_t comm);
|
||||
|
||||
mscclppResult_t mscclppProxyStop(mscclppComm_t comm);
|
||||
|
||||
mscclppResult_t mscclppCommRank(mscclppComm_t comm, int* rank);
|
||||
|
||||
mscclppResult_t mscclppCommSize(mscclppComm_t comm, int* size);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // end extern "C"
|
||||
#endif
|
||||
|
||||
22
src/init.cc
22
src/init.cc
@@ -402,3 +402,25 @@ mscclppResult_t mscclppProxyStop(mscclppComm_t comm)
|
||||
MSCCLPPCHECK(mscclppProxyDestroy(comm));
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
MSCCLPP_API(mscclppResult_t, mscclppCommRank, mscclppComm_t comm, int* rank);
|
||||
mscclppResult_t mscclppCommRank(mscclppComm_t comm, int* rank)
|
||||
{
|
||||
if (comm == NULL || rank == NULL) {
|
||||
WARN("comm or rank cannot be null");
|
||||
return mscclppInvalidUsage;
|
||||
}
|
||||
*rank = comm->rank;
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
MSCCLPP_API(mscclppResult_t, mscclppCommSize, mscclppComm_t comm, int* size);
|
||||
mscclppResult_t mscclppCommSize(mscclppComm_t comm, int* size)
|
||||
{
|
||||
if (comm == NULL || size == NULL) {
|
||||
WARN("comm or size cannot be null");
|
||||
return mscclppInvalidUsage;
|
||||
}
|
||||
*size = comm->nRanks;
|
||||
return mscclppSuccess;
|
||||
}
|
||||
@@ -81,8 +81,8 @@ void* mscclppProxyServiceP2P(void* _args) {
|
||||
|
||||
// Iterate over what send is needed
|
||||
if (trigger.fields.type & mscclppData){
|
||||
void *srcBuff = (void *)((char *)conn->devConn->localBuff + trigger.fields.dataOffset);
|
||||
void *dstBuff = (void *)((char *)conn->devConn->remoteBuff + trigger.fields.dataOffset);
|
||||
void *srcBuff = (void *)((char *)conn->devConn->localBuff + trigger.fields.srcDataOffset);
|
||||
void *dstBuff = (void *)((char *)conn->devConn->remoteBuff + trigger.fields.dstDataOffset);
|
||||
PROXYCUDACHECK(cudaMemcpyAsync(dstBuff, srcBuff, trigger.fields.dataSize, cudaMemcpyDeviceToDevice, stream));
|
||||
}
|
||||
if (trigger.fields.type & mscclppFlag) {
|
||||
@@ -222,12 +222,13 @@ void* mscclppProxyServiceIb(void* _args) {
|
||||
|
||||
if (trigger.fields.type & mscclppData) {
|
||||
conn->ibQp->stageSend(conn->ibBuffMr, &conn->ibBuffMrInfo, (uint32_t)trigger.fields.dataSize,
|
||||
/*wrId=*/0, /*offset=*/trigger.fields.dataOffset, /*signaled=*/false);
|
||||
/*wrId=*/0, /*srcOffset=*/trigger.fields.srcDataOffset, /*dstOffset=*/trigger.fields.dstDataOffset,
|
||||
/*signaled=*/false);
|
||||
}
|
||||
if (trigger.fields.type & mscclppFlag) {
|
||||
// My local flag is copied to the peer's proxy flag
|
||||
conn->ibQp->stageSend(conn->ibLocalFlagMr, &conn->ibProxyFlagMrInfo, sizeof(uint64_t),
|
||||
/*wrId=*/0, /*offset=*/0, /*signaled=*/true);
|
||||
/*wrId=*/0, /*srcOffset=*/0, /*dstOffset=*/0, /*signaled=*/true);
|
||||
}
|
||||
int ret;
|
||||
if ((ret = conn->ibQp->postSend()) != 0) {
|
||||
|
||||
@@ -89,9 +89,9 @@ __global__ void kernel(int rank, int world_size, int nelemsPerGPU)
|
||||
// Trigger sending data, flag and synchronize after
|
||||
int ibPortion = nelemsPerGPU/12;//nelemsPerGPU/12;
|
||||
if (isIB)
|
||||
devConn.fifo.setTrigger(trig, mscclppFlag | mscclppData | mscclppSync, rank * nelemsPerGPU * sizeof(int) + (nelemsPerGPU - ibPortion)*sizeof(int), ibPortion*sizeof(int));
|
||||
devConn.fifo.setTrigger(trig, mscclppFlag | mscclppData | mscclppSync, rank * nelemsPerGPU * sizeof(int) + (nelemsPerGPU - ibPortion)*sizeof(int), rank * nelemsPerGPU * sizeof(int) + (nelemsPerGPU - ibPortion)*sizeof(int), ibPortion*sizeof(int));
|
||||
else
|
||||
devConn.fifo.setTrigger(trig, mscclppFlag | mscclppData | mscclppSync, rank * nelemsPerGPU * sizeof(int), (nelemsPerGPU-ibPortion)*sizeof(int));
|
||||
devConn.fifo.setTrigger(trig, mscclppFlag | mscclppData | mscclppSync, rank * nelemsPerGPU * sizeof(int), rank * nelemsPerGPU * sizeof(int), (nelemsPerGPU-ibPortion)*sizeof(int));
|
||||
// Wait on the request to make sure it is safe to reuse buffer and flag
|
||||
devConn.fifo.waitTrigger(req);
|
||||
}
|
||||
|
||||
297
tests/allreduce_allpairs_test.cu
Normal file
297
tests/allreduce_allpairs_test.cu
Normal file
@@ -0,0 +1,297 @@
|
||||
#include "mscclpp.h"
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
#include <cuda/barrier>
|
||||
|
||||
#include "common.h"
|
||||
|
||||
// Evaluates an mscclpp call exactly once. If the result is neither
// mscclppSuccess nor mscclppInProgress, prints the failing file/line and the
// result code, then returns that result from the enclosing function.
// The macro deliberately has no trailing semicolon so that
// `MSCCLPPCHECK(x);` expands to a single statement (safe in unbraced if/else).
#define MSCCLPPCHECK(call) do { \
  mscclppResult_t res = (call); \
  if (res != mscclppSuccess && res != mscclppInProgress) { \
    /* Report where the failure happened */ \
    printf("Failure at %s:%d -> %d\n", __FILE__, __LINE__, res); \
    return res; \
  } \
} while (0)
|
||||
|
||||
// Runs a CUDA runtime call; when it does not return cudaSuccess, prints the
// source location together with the CUDA error string and terminates the
// process with EXIT_FAILURE.
#define CUDACHECK(cmd) do { \
  cudaError_t err = cmd; \
  if (cudaSuccess != err) { \
    printf("%s:%d Cuda failure '%s'\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
    exit(EXIT_FAILURE); \
  } \
} while (0)
|
||||
|
||||
// A contiguous range of elements: it begins at element `offset` and covers
// `size` elements.
struct Volume {
  size_t offset;
  size_t size;
};

// Computes the range covered by `chunkCount` consecutive chunks, starting at
// chunk `chunkIdx`, when `totalSize` elements are split into `totalChunks`
// chunks. The first (totalSize % totalChunks) chunks hold one extra element,
// so chunk sizes differ by at most one.
//
// Returns an empty Volume when `totalChunks` is zero. The range is not
// clamped: callers are expected to keep chunkIdx + chunkCount <= totalChunks.
__host__ __device__ Volume chunkVolume(size_t totalSize, size_t totalChunks, size_t chunkIdx, size_t chunkCount) {
  if (totalChunks == 0) {
    return Volume{0, 0};
  }
  size_t remainder = totalSize % totalChunks;
  size_t smallChunk = totalSize / totalChunks;
  size_t largeChunk = smallChunk + 1;

  // Chunks that precede chunkIdx: large ones come first, then small ones.
  size_t largeBefore = chunkIdx < remainder ? chunkIdx : remainder;
  size_t smallBefore = chunkIdx - largeBefore;

  // Large chunks inside the requested range, limited by the range length so
  // that the small-chunk count below cannot underflow.
  size_t largeAvailable = remainder - largeBefore;
  size_t numLargeChunks = largeAvailable < chunkCount ? largeAvailable : chunkCount;
  size_t numSmallChunks = chunkCount - numLargeChunks;

  size_t offset = largeBefore * largeChunk + smallBefore * smallChunk;
  return Volume{offset, numLargeChunks * largeChunk + numSmallChunks * smallChunk};
}
|
||||
|
||||
// Device-side state and algorithm of an all-pairs allreduce over `userSize`
// elements of T. `reduce(dst, src, n)` is the block-level reduction applied
// to each received chunk. The object is passed by value to a kernel; each
// thread block calls run() with its block index and handles one peer.
template<class T, void (*reduce)(T*,T*,size_t)>
struct AllreduceAllpairs {
  int rank;            // this process's rank
  int nRanks;          // number of ranks in the communicator
  T* userData;         // buffer being reduced
  size_t userSize;     // number of T elements in userData
  T* scratch;          // receive area for the peers' copies of this rank's chunk
  size_t scratchSize;  // number of T elements in scratch
  mscclppDevConn_t* conns;  // three connections per peer, see *ConnIdx() below
  uint64_t* connFlags;
  cuda::barrier<cuda::thread_scope_device>* barrier;

  // Runs the allreduce for block `idx`: exchange chunks with the peer owned by
  // this block, reduce the received chunks locally, then exchange the reduced
  // chunk. Must be called by every thread of the block.
  __device__ void run(int idx) {
    int myPeer = peerRank(idx, rank);
    mscclppDevConn_t phase1SendConn = conns[phase1SendConnIdx(myPeer)];
    mscclppDevConn_t phase1RecvConn = conns[phase1RecvConnIdx(myPeer)];
    mscclppDevConn_t phase2Conn = conns[phase2ConnIdx(myPeer)];

    // 1st communication phase: send data to the scratch buffer of the peer associated with this block
    Volume toPeer = chunkVolume(userSize, nRanks, myPeer, 1);
    // Now we need to figure out the offset of this chunk in the scratch buffer of the destination.
    // The destination will have allocated a scratch buffer of size numPeers() * toPeer.size and
    // inside that each of the destination's peers send to the nth chunk, where n is the index of the
    // source peer from the destination's perspective.
    size_t dstOffset = peerIdx(rank, myPeer) * toPeer.size;
    send(phase1SendConn, toPeer.offset, dstOffset, toPeer.size);
    recv(phase1RecvConn);

    syncAllBlocks();

    // Local reduction: every block reduces a slice of each chunk in the scratch buffer into the user buffer
    Volume rankUserChunk = chunkVolume(userSize, nRanks, rank, 1);
    T* userChunk = userData + rankUserChunk.offset;
    Volume blockUserChunk = chunkVolume(rankUserChunk.size, numBlocks(), idx, 1);

    // The scratch layout does not depend on which peer chunk is being
    // reduced, so it is validated and computed once, outside the loop.
    assert(scratchSize % numPeers() == 0);
    assert(scratchSize / numPeers() == rankUserChunk.size);
    size_t scratchChunkSize = scratchSize / numPeers();
    Volume blockScratchChunk = chunkVolume(scratchChunkSize, numBlocks(), idx, 1);
    assert(blockScratchChunk.size == blockUserChunk.size);

    for (int chunk = 0; chunk < numPeers(); ++chunk) {
      T* scratchChunk = scratch + chunk * scratchChunkSize;
      reduce(userChunk + blockUserChunk.offset, scratchChunk + blockScratchChunk.offset, blockScratchChunk.size);
    }

    syncAllBlocks();

    // 2nd communication phase: send the now reduced data between the user buffers
    Volume srcVolume2 = chunkVolume(userSize, nRanks, rank, 1);
    send(phase2Conn, srcVolume2.offset, srcVolume2.offset, srcVolume2.size);
    recv(phase2Conn);
  }

  // Thread 0 of the block arrives at the device-scope barrier and waits; the
  // remaining threads of the block are then released by __syncthreads().
  __device__ void syncAllBlocks() {
    if (threadIdx.x == 0)
      barrier->arrive_and_wait();
    __syncthreads();
  }

  // Posts a data+flag trigger on `conn` for `size` elements, read at element
  // offset `srcOffset` and written at element offset `dstOffset` (offsets and
  // size are converted to bytes here). Only thread 0 posts; all threads of
  // the block synchronize afterwards.
  __device__ void send(mscclppDevConn_t& conn, size_t srcOffset, size_t dstOffset, size_t size) {
    if (threadIdx.x == 0) {
      volatile uint64_t *localFlag = conn.localFlag;
      *localFlag = 1; // 1 is used to signal the send

      mscclppTrigger_t trigger;
      // The request handle returned by getTrigger is not used: this method
      // never waits on it.
      conn.fifo.getTrigger(&trigger);
      conn.fifo.setTrigger(trigger, mscclppData | mscclppFlag, srcOffset * sizeof(T), dstOffset * sizeof(T), size * sizeof(T));
    }
    __syncthreads();
  }

  // Thread 0 spins until the connection's proxy flag reads 1, then resets it
  // to 0; all threads of the block synchronize afterwards.
  __device__ void recv(mscclppDevConn_t& conn) {
    if (threadIdx.x == 0) {
      volatile uint64_t *proxyFlag = conn.proxyFlag;
      while (*proxyFlag != 1) {}
      *proxyFlag = 0;
    }
    __syncthreads();
  }

  // Number of other ranks.
  __host__ __device__ int numPeers() {
    return nRanks - 1;
  }

  // Number of thread blocks the kernel is launched with: one per peer.
  __host__ __device__ int numBlocks() {
    return numPeers();
  }

  // Maps a peer's rank to its dense index among `myRank`'s peers (ranks with
  // `myRank` itself skipped).
  __host__ __device__ int peerIdx(int peerRank, int myRank) {
    return peerRank < myRank ? peerRank : peerRank - 1;
  }

  // Inverse of peerIdx(): maps a dense peer index back to a rank.
  __host__ __device__ int peerRank(int peerIdx, int myRank) {
    return peerIdx < myRank ? peerIdx : peerIdx + 1;
  }

  // Each peer owns three consecutive slots in `conns`:
  // phase-1 send, phase-1 receive, phase-2.
  __host__ __device__ int phase1SendConnIdx(int peerRank) {
    return peerIdx(peerRank, rank) * 3;
  }

  __host__ __device__ int phase1RecvConnIdx(int peerRank) {
    return peerIdx(peerRank, rank) * 3 + 1;
  }

  __host__ __device__ int phase2ConnIdx(int peerRank) {
    return peerIdx(peerRank, rank) * 3 + 2;
  }

  // Frees every device allocation this object points to (the user buffer is
  // not touched) and nulls the pointers so a second call is a no-op.
  void freeGPUResources() {
    if (scratch)
      CUDACHECK(cudaFree(scratch));
    scratch = nullptr;
    if (connFlags)
      CUDACHECK(cudaFree(connFlags));
    connFlags = nullptr;
    if (conns)
      CUDACHECK(cudaFree(conns));
    conns = nullptr;
    if (barrier)
      CUDACHECK(cudaFree(barrier));
    barrier = nullptr;
  }
};
|
||||
|
||||
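// The builder class encapsulates the host-side setup of the allreduce: it allocates the device
// buffers, registers the connections and produces the AllreduceAllpairs object passed to the kernel.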
|
||||
// Host-side helper that assembles an AllreduceAllpairs object in three steps:
// construct with the user buffer, connect() against a communicator, then
// finishSetup() once the connections have been established.
template<class T, void (*reduce)(T*,T*,size_t)>
class AllreduceAllpairsBuilder {
  AllreduceAllpairs<T, reduce> d;
  std::vector<mscclppDevConn_t> hostConns;

public:

  // The constructor is called after the user has allocated the buffer to be allreduced
  AllreduceAllpairsBuilder(T* data, size_t size) {
    // Device pointers start out null; they are allocated in connect() and finishSetup().
    d.scratch = nullptr;
    d.connFlags = nullptr;
    d.conns = nullptr;
    d.barrier = nullptr;
    d.userData = data;
    d.userSize = size;
  }

  // connect is called after rank initialization but before connection setup
  mscclppResult_t connect(mscclppComm_t comm) {
    MSCCLPPCHECK(mscclppCommRank(comm, &d.rank));
    MSCCLPPCHECK(mscclppCommSize(comm, &d.nRanks));

    // The scratch buffer holds one copy of this rank's chunk per peer.
    d.scratchSize = chunkVolume(d.userSize, d.nRanks, d.rank, 1).size * d.numPeers();

    const size_t userBytes = d.userSize * sizeof(T);
    const size_t scratchBytes = d.scratchSize * sizeof(T);
    const size_t flagBytes = 3 * sizeof(uint64_t);

    CUDACHECK(cudaMalloc(&d.scratch, scratchBytes));
    CUDACHECK(cudaMalloc(&d.connFlags, flagBytes));
    CUDACHECK(cudaMemset(d.connFlags, 0, flagBytes));

    hostConns.resize(d.numPeers() * 3);
    mscclppDevConn_t* connBase = hostConns.data();
    for (int peer = 0; peer < d.nRanks; ++peer) {
      if (peer == d.rank) {
        continue;
      }
      // The lower rank of a pair sends with tag 0 and receives with tag 1;
      // the higher rank uses the opposite assignment.
      const bool iAmLower = d.rank < peer;
      const int sendTag = iAmLower ? 0 : 1;
      const int recvTag = iAmLower ? 1 : 0;
      MSCCLPPCHECK(mscclppConnect(comm, connBase + d.phase1SendConnIdx(peer), peer, d.userData, userBytes, d.connFlags + 0, sendTag, mscclppTransportP2P, nullptr));
      MSCCLPPCHECK(mscclppConnect(comm, connBase + d.phase1RecvConnIdx(peer), peer, d.scratch, scratchBytes, d.connFlags + 1, recvTag, mscclppTransportP2P, nullptr));
      MSCCLPPCHECK(mscclppConnect(comm, connBase + d.phase2ConnIdx(peer), peer, d.userData, userBytes, d.connFlags + 2, 2, mscclppTransportP2P, nullptr));
    }

    return mscclppSuccess;
  }

  // finishSetup is called after connection setup and returns an algorithm object that is ready to be passed to a GPU kernel
  AllreduceAllpairs<T, reduce> finishSetup() {
    using DeviceBarrier = cuda::barrier<cuda::thread_scope_device>;

    // Upload the host-side connection table.
    const size_t connBytes = hostConns.size() * sizeof(mscclppDevConn_t);
    CUDACHECK(cudaMalloc(&d.conns, connBytes));
    CUDACHECK(cudaMemcpy(d.conns, hostConns.data(), connBytes, cudaMemcpyHostToDevice));

    // Build a barrier expecting one arrival per block on the host and copy
    // its bytes into device memory.
    CUDACHECK(cudaMalloc(&d.barrier, sizeof(DeviceBarrier)));
    DeviceBarrier initBarrier(d.numBlocks());
    CUDACHECK(cudaMemcpy(d.barrier, &initBarrier, sizeof(DeviceBarrier), cudaMemcpyHostToDevice));
    return d;
  }
};
|
||||
|
||||
// Block-level element-wise sum: dst[i] += src[i] for i in [0, size).
// Thread t of the block handles elements t, t + blockDim.x, ... so every
// thread of the block must call this. The index is size_t to match `size`
// (no signed/unsigned comparison, no overflow for sizes above INT_MAX).
template<class T>
__device__ void reduceSum(T* dst, T* src, size_t size) {
  for (size_t i = threadIdx.x; i < size; i += blockDim.x) {
    dst[i] += src[i];
  }
}
|
||||
|
||||
// Fills data[0 .. size) with `rank`. Works for any 1D launch configuration:
// each thread starts at its global index and advances by the total number of
// launched threads, so a single-block launch behaves exactly as before and a
// multi-block launch no longer writes every element once per block.
template<class T>
__global__ void init(T* data, size_t size, int rank) {
  const size_t stride = (size_t)gridDim.x * blockDim.x;
  for (size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x; i < size; i += stride) {
    data[i] = rank;
  }
}
|
||||
|
||||
// Test kernel entry point: every thread block runs the allreduce with its
// own block index as the block/peer slot.
template<class T>
__global__ void testKernel(AllreduceAllpairs<T, reduceSum> d) {
  const int blockId = blockIdx.x;
  d.run(blockId);
}
|
||||
|
||||
// Allreduce test driver: fills a 1 MB int buffer with this process's rank,
// runs the all-pairs allreduce across all ranks and verifies that every
// element equals 0 + 1 + ... + (world_size - 1).
// Returns 0 on success, 1 on a data mismatch, or the failing mscclpp result.
int main(int argc, const char *argv[]) {
#ifdef MSCCLPP_USE_MPI_FOR_TESTS
  MPI_Init(NULL, NULL);
#endif
  const char* ip_port = nullptr;
  int rank, world_size;
  parse_arguments(argc, argv, &ip_port, &rank, &world_size);

  CUDACHECK(cudaSetDevice(rank));

  // Allocate and initialize 1 MB of data
  int *data;
  size_t dataSize = 1024 * 1024 / sizeof(int);
  CUDACHECK(cudaMalloc(&data, dataSize * sizeof(int)));
  init<<<1, 256>>>(data, dataSize, rank);
  CUDACHECK(cudaGetLastError());  // kernel launches do not report errors directly

  // Create the collective
  AllreduceAllpairsBuilder<int, reduceSum> builder(data, dataSize);

  // Create the communicator
  mscclppComm_t comm;
  MSCCLPPCHECK(mscclppCommInitRank(&comm, world_size, rank, ip_port));

  // Connect the collective; a failure here must stop the test
  MSCCLPPCHECK(builder.connect(comm));

  // Finish the setup
  MSCCLPPCHECK(mscclppConnectionSetup(comm));
  MSCCLPPCHECK(mscclppProxyLaunch(comm));
  auto allreduce = builder.finishSetup();

  // Run the collective. With a single rank there are no peers and therefore
  // no blocks to launch; a zero-block launch would be an invalid configuration.
  if (allreduce.numBlocks() > 0) {
    testKernel<<<allreduce.numBlocks(), 256>>>(allreduce);
    CUDACHECK(cudaGetLastError());
  }

  // Wait for kernel to finish
  CUDACHECK(cudaDeviceSynchronize());

  // Check the result (std::vector releases the host copy on every return path)
  std::vector<int> hostData(dataSize);
  CUDACHECK(cudaMemcpy(hostData.data(), data, dataSize * sizeof(int), cudaMemcpyDeviceToHost));
  int expectedValue = world_size * (world_size - 1) / 2;
  for (size_t i = 0; i < dataSize; ++i) {
    if (hostData[i] != expectedValue) {
      printf("Error at index %zu: %d != %d\n", i, hostData[i], expectedValue);
      return 1;
    }
  }

  MSCCLPPCHECK(mscclppProxyStop(comm));

  MSCCLPPCHECK(mscclppCommDestroy(comm));

#ifdef MSCCLPP_USE_MPI_FOR_TESTS
  // MPI_Init above is unconditional, so finalization must be as well.
  MPI_Finalize();
#endif
  printf("Succeeded! %d\n", rank);
  return 0;
}
|
||||
45
tests/common.h
Normal file
45
tests/common.h
Normal file
@@ -0,0 +1,45 @@
|
||||
#ifndef MSCCLPP_TESTS_COMMON_H_
|
||||
#define MSCCLPP_TESTS_COMMON_H_
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#ifdef MSCCLPP_USE_MPI_FOR_TESTS
|
||||
#include "mpi.h"
|
||||
#endif // MSCCLPP_USE_MPI_FOR_TESTS
|
||||
|
||||
// Prints the command-line usage of a test program named `prog` to stdout.
// In MPI builds the rank and rank count are optional.
// Declared inline because this header defines the function; without it,
// including the header from two translation units is a multiple-definition
// link error.
inline void print_usage(const char *prog)
{
#ifdef MSCCLPP_USE_MPI_FOR_TESTS
  const char *usageFormat = "usage: %s IP:PORT [rank nranks]\n";
#else
  const char *usageFormat = "usage: %s IP:PORT rank nranks\n";
#endif
  printf(usageFormat, prog);
}
|
||||
|
||||
void parse_arguments(int argc, const char *argv[], const char** ip_port, int* rank, int* world_size) {
|
||||
#ifdef MSCCLPP_USE_MPI_FOR_TESTS
|
||||
if (argc != 2 && argc != 4) {
|
||||
print_usage(argv[0]);
|
||||
exit(-1);
|
||||
}
|
||||
*ip_port = argv[1];
|
||||
if (argc == 4) {
|
||||
*rank = atoi(argv[2]);
|
||||
*world_size = atoi(argv[3]);
|
||||
} else {
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, rank);
|
||||
MPI_Comm_size(MPI_COMM_WORLD, world_size);
|
||||
}
|
||||
#else
|
||||
if (argc != 4) {
|
||||
print_usage(argv[0]);
|
||||
exit(-1);
|
||||
}
|
||||
const char *ip_port = argv[1];
|
||||
*rank = atoi(argv[2]);
|
||||
*world_size = atoi(argv[3]);
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif // MSCCLPP_TESTS_COMMON_H_
|
||||
@@ -1,12 +1,11 @@
|
||||
#include "mscclpp.h"
|
||||
#ifdef MSCCLPP_USE_MPI_FOR_TESTS
|
||||
#include "mpi.h"
|
||||
#endif // MSCCLPP_USE_MPI_FOR_TESTS
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <unistd.h>
|
||||
#include <string>
|
||||
|
||||
#include "common.h"
|
||||
|
||||
#define RANKS_PER_NODE 8
|
||||
#define USE_DMA_FOR_P2P 1
|
||||
#define TEST_CONN_TYPE 0 // 0: P2P(for local)+IB(for remote), 1: IB-Only
|
||||
@@ -147,42 +146,14 @@ int cudaNumToIbNum(int cudaNum)
|
||||
return ibNum;
|
||||
}
|
||||
|
||||
void print_usage(const char *prog)
|
||||
{
|
||||
#ifdef MSCCLPP_USE_MPI_FOR_TESTS
|
||||
printf("usage: %s IP:PORT [rank nranks]\n", prog);
|
||||
#else
|
||||
printf("usage: %s IP:PORT rank nranks\n", prog);
|
||||
#endif
|
||||
}
|
||||
|
||||
int main(int argc, const char *argv[])
|
||||
{
|
||||
#ifdef MSCCLPP_USE_MPI_FOR_TESTS
|
||||
if (argc != 2 && argc != 4) {
|
||||
print_usage(argv[0]);
|
||||
return -1;
|
||||
}
|
||||
const char *ip_port = argv[1];
|
||||
int rank;
|
||||
int world_size;
|
||||
if (argc == 4) {
|
||||
rank = atoi(argv[2]);
|
||||
world_size = atoi(argv[3]);
|
||||
} else {
|
||||
MPI_Init(NULL, NULL);
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &world_size);
|
||||
}
|
||||
#else
|
||||
if (argc != 4) {
|
||||
print_usage(argv[0]);
|
||||
return -1;
|
||||
}
|
||||
const char *ip_port = argv[1];
|
||||
int rank = atoi(argv[2]);
|
||||
int world_size = atoi(argv[3]);
|
||||
MPI_Init(NULL, NULL);
|
||||
#endif
|
||||
const char* ip_port;
|
||||
int rank, world_size;
|
||||
parse_arguments(argc, argv, &ip_port, &rank, &world_size);
|
||||
int localRank = rankToLocalRank(rank);
|
||||
int thisNode = rankToNode(rank);
|
||||
int cudaNum = localRank;
|
||||
|
||||
Reference in New Issue
Block a user