diff --git a/src/bootstrap/ib.cc b/src/bootstrap/ib.cc index b8defcf7..f266b3a4 100644 --- a/src/bootstrap/ib.cc +++ b/src/bootstrap/ib.cc @@ -298,8 +298,8 @@ int mscclppIbQp::rts() IBV_QP_MAX_QP_RD_ATOMIC); } -int mscclppIbQp::stageSend(struct mscclppIbMr *ibMr, const mscclppIbMrInfo *info, int size, - uint64_t wrId, unsigned int immData, int offset) +int mscclppIbQp::stageSend(struct mscclppIbMr *ibMr, const mscclppIbMrInfo *info, uint32_t size, + uint64_t wrId, unsigned int immData, uint64_t offset, bool signaled) { if (this->wrn >= MSCCLPP_IB_MAX_SENDS) { return -1; @@ -314,11 +314,11 @@ int mscclppIbQp::stageSend(struct mscclppIbMr *ibMr, const mscclppIbMrInfo *info wr_->num_sge = 1; wr_->opcode = IBV_WR_RDMA_WRITE_WITH_IMM; wr_->imm_data = immData; - wr_->send_flags = IBV_SEND_SIGNALED; - wr_->wr.rdma.remote_addr = info->addr; + wr_->send_flags = signaled ? IBV_SEND_SIGNALED : 0; + wr_->wr.rdma.remote_addr = (uint64_t)(info->addr) + offset; wr_->wr.rdma.rkey = info->rkey; wr_->next = nullptr; - sge_->addr = (uint64_t)(ibMr->buff) + (uint64_t)offset; + sge_->addr = (uint64_t)(ibMr->buff) + offset; sge_->length = size; sge_->lkey = ibMr->mr->lkey; if (wrn > 0) { diff --git a/src/bootstrap/init.cc b/src/bootstrap/init.cc index 4a2d60a6..95cf0551 100644 --- a/src/bootstrap/init.cc +++ b/src/bootstrap/init.cc @@ -141,105 +141,112 @@ mscclppResult_t mscclppCommDestroy(mscclppComm_t comm){ return mscclppSuccess; } -MSCCLPP_API(mscclppResult_t, mscclppConnect, mscclppComm_t comm, int rankRecv, int rankSend, +MSCCLPP_API(mscclppResult_t, mscclppConnect, mscclppComm_t comm, int remoteRank, void *buff, size_t buffSize, int *flag, int tag, mscclppTransport_t transportType, const char *ibDev); -mscclppResult_t mscclppConnect(mscclppComm_t comm, int rankRecv, int rankSend, void *buff, size_t buffSize, +mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, void *buff, size_t buffSize, int *flag, int tag, mscclppTransport_t transportType, const char *ibDev/*=NULL*/) { - if (comm->rank == rankRecv || comm->rank == rankSend) { - struct mscclppConn *conn = &comm->conns[comm->nConns++]; - conn->transport = transportType; - conn->rankSend = rankSend; - conn->rankRecv = rankRecv; - conn->tag = tag; - conn->buff = buff; - conn->buffSize = buffSize; - conn->flag = flag; - conn->ibCtx = NULL; - conn->ibQp = NULL; + struct mscclppConn *conn = &comm->conns[comm->nConns++]; + conn->transport = transportType; + conn->remoteRank = remoteRank; + conn->tag = tag; + conn->buff = buff; + conn->buffSize = buffSize; + conn->flag = flag; + conn->ibCtx = NULL; + conn->ibQp = NULL; - if (ibDev != NULL) { - // Check if an IB context exists - int ibDevIdx = -1; - int firstNullIdx = -1; - for (int i = 0; i < MSCCLPP_IB_MAX_DEVS; ++i) { - if (comm->ibContext[i] == NULL) { - if (firstNullIdx == -1) { - firstNullIdx = i; - } - } else if (strncmp(comm->ibContext[i]->ctx->device->name, ibDev, IBV_SYSFS_NAME_MAX) == 0) { - ibDevIdx = i; - break; - } - } - if (ibDevIdx == -1) { - // Create a new context. + if (ibDev != NULL) { + // Check if an IB context exists + int ibDevIdx = -1; + int firstNullIdx = -1; + for (int i = 0; i < MSCCLPP_IB_MAX_DEVS; ++i) { + if (comm->ibContext[i] == NULL) { if (firstNullIdx == -1) { - WARN("Too many IB devices"); - return mscclppInvalidUsage; - } - ibDevIdx = firstNullIdx; - if (mscclppIbContextCreate(&comm->ibContext[ibDevIdx], ibDev) != mscclppSuccess) { - WARN("Failed to create IB context"); - return mscclppInternalError; + firstNullIdx = i; } + } else if (strncmp(comm->ibContext[i]->ctx->device->name, ibDev, IBV_SYSFS_NAME_MAX) == 0) { + ibDevIdx = i; + break; } - conn->ibCtx = comm->ibContext[ibDevIdx]; } + if (ibDevIdx == -1) { + // Create a new context. + if (firstNullIdx == -1) { + WARN("Too many IB devices"); + return mscclppInvalidUsage; + } + ibDevIdx = firstNullIdx; + if (mscclppIbContextCreate(&comm->ibContext[ibDevIdx], ibDev) != mscclppSuccess) { + WARN("Failed to create IB context"); + return mscclppInternalError; + } + } + conn->ibCtx = comm->ibContext[ibDevIdx]; } return mscclppSuccess; } +struct connInfo { + cudaIpcMemHandle_t handleBuff; + cudaIpcMemHandle_t handleFlag; + mscclppIbQpInfo infoQp; + mscclppIbMrInfo infoBuffMr; + mscclppIbMrInfo infoLocalFlagMr; + mscclppIbMrInfo infoRemoteFlagMr; +}; + MSCCLPP_API(mscclppResult_t, mscclppConnectionSetup, mscclppComm_t comm); mscclppResult_t mscclppConnectionSetup(mscclppComm_t comm) { - struct connInfo { - cudaIpcMemHandle_t handleBuff; - cudaIpcMemHandle_t handleFlag; - mscclppIbQpInfo qpInfo; - mscclppIbMrInfo mrInfo; - }; + // Allocate connection info to be shared with GPU + MSCCLPPCHECK(mscclppCudaHostCalloc(&comm->devConns, comm->nConns)); // Send info to peers for (int i = 0; i < comm->nConns; ++i) { struct mscclppConn *conn = &comm->conns[i]; + struct mscclppDevConn *devConn = &comm->devConns[i]; + conn->devConn = devConn; + devConn->tag = conn->tag; + devConn->localBuff = conn->buff; + devConn->localFlag = conn->flag; + MSCCLPPCHECK(mscclppCudaHostCalloc(&devConn->trigger, 1)); + struct connInfo cInfo; if (conn->transport == mscclppTransportP2P) { - CUDACHECK(cudaIpcGetMemHandle(&cInfo.handleBuff, conn->buff)); - CUDACHECK(cudaIpcGetMemHandle(&cInfo.handleFlag, conn->flag)); + CUDACHECK(cudaIpcGetMemHandle(&cInfo.handleBuff, devConn->localBuff)); + CUDACHECK(cudaIpcGetMemHandle(&cInfo.handleFlag, devConn->localFlag)); } else if (conn->transport == mscclppTransportIB) { + devConn->remoteBuff = NULL; + MSCCLPPCHECK(mscclppCudaCalloc(&devConn->remoteFlag, 1)); + struct mscclppIbContext *ibCtx = conn->ibCtx; if (conn->ibQp == NULL) { MSCCLPPCHECK(mscclppIbContextCreateQp(ibCtx, &conn->ibQp)); } - MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, conn->buff, conn->buffSize, &conn->ibMr)); - cInfo.qpInfo = conn->ibQp->info; - cInfo.mrInfo = conn->ibMr->info; + MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->localBuff, conn->buffSize, &conn->ibBuffMr)); + MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->localFlag, sizeof(int), &conn->ibLocalFlagMr)); + MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->remoteFlag, sizeof(int), &conn->ibRemoteFlagMr)); + cInfo.infoQp = conn->ibQp->info; + cInfo.infoBuffMr = conn->ibBuffMr->info; + cInfo.infoLocalFlagMr = conn->ibLocalFlagMr->info; + cInfo.infoRemoteFlagMr = conn->ibRemoteFlagMr->info; } - int peer = conn->rankSend == comm->rank ? conn->rankRecv : conn->rankSend; - MSCCLPPCHECK(bootstrapSend(comm->bootstrap, peer, conn->tag, &cInfo, sizeof(cInfo))); + MSCCLPPCHECK(bootstrapSend(comm->bootstrap, conn->remoteRank, conn->tag, &cInfo, sizeof(cInfo))); } - // Allocate connection info to be shared with GPU - MSCCLPPCHECK(mscclppCudaHostCalloc(&comm->devConns, comm->nConns)); - // Recv info from peers for (int i = 0; i < comm->nConns; ++i) { struct mscclppConn *conn = &comm->conns[i]; struct mscclppDevConn *devConn = &comm->devConns[i]; - devConn->tag = conn->tag; - devConn->localBuff = conn->buff; - devConn->localFlag = conn->flag; - struct connInfo cInfo; - int peer = conn->rankSend == comm->rank ? conn->rankRecv : conn->rankSend; - MSCCLPPCHECK(bootstrapRecv(comm->bootstrap, peer, conn->tag, &cInfo, sizeof(cInfo))); + MSCCLPPCHECK(bootstrapRecv(comm->bootstrap, conn->remoteRank, conn->tag, &cInfo, sizeof(cInfo))); if (conn->transport == mscclppTransportP2P) { CUDACHECK(cudaIpcOpenMemHandle(&devConn->remoteBuff, cInfo.handleBuff, cudaIpcMemLazyEnablePeerAccess)); CUDACHECK(cudaIpcOpenMemHandle((void **)&devConn->remoteFlag, cInfo.handleFlag, cudaIpcMemLazyEnablePeerAccess)); } else if (conn->transport == mscclppTransportIB) { - if (conn->ibQp->rtr(&cInfo.qpInfo) != 0) { + if (conn->ibQp->rtr(&cInfo.infoQp) != 0) { WARN("Failed to transition QP to RTR"); return mscclppInvalidUsage; } @@ -247,9 +254,9 @@ mscclppResult_t mscclppConnectionSetup(mscclppComm_t comm) WARN("Failed to transition QP to RTS"); return mscclppInvalidUsage; } - conn->ibRemoteMrInfo = cInfo.mrInfo; - devConn->remoteBuff = NULL; - CUDACHECK(cudaMalloc(&devConn->remoteFlag, sizeof(int))); + conn->ibBuffMrInfo = cInfo.infoBuffMr; + conn->ibLocalFlagMrInfo = cInfo.infoLocalFlagMr; + conn->ibRemoteFlagMrInfo = cInfo.infoRemoteFlagMr; } } diff --git a/src/bootstrap/proxy.cc b/src/bootstrap/proxy.cc index cee63587..56e9467d 100644 --- a/src/bootstrap/proxy.cc +++ b/src/bootstrap/proxy.cc @@ -8,6 +8,8 @@ #include #include +#define MSCCLPP_PROXY_FLAG_SET_BY_RDMA 1 + struct proxyArgs { struct mscclppComm* comm; struct mscclppIbContext* ibCtx; @@ -27,45 +29,47 @@ void* mscclppProxyService(void* _args) { }; int rank = comm->rank; - std::map recvTagToConn; - std::map sendTagToConn; - std::map sendConnToState; + std::map qpNumToConn; + std::map> trigToSendStateAndConn; for (int i = 0; i < comm->nConns; ++i) { struct mscclppConn *conn = &comm->conns[i]; if (conn->transport != mscclppTransportIB) continue; if (conn->ibCtx != ibCtx) continue; - if (conn->rankRecv == rank) { - recvTagToConn[conn->tag] = conn; - } else if (conn->rankSend == rank) { - sendTagToConn[conn->tag] = conn; - sendConnToState[conn] = SEND_STATE_INIT; - } - } - // Initial post recv - for (auto &pair : recvTagToConn) { - struct mscclppConn *conn = pair.second; - int tag = pair.first; - if (conn->ibQp->postRecv((uint64_t)-tag) != 0) { + volatile uint64_t *tmp = (volatile uint64_t *)conn->devConn->trigger; + trigToSendStateAndConn[tmp].first = SEND_STATE_INIT; + trigToSendStateAndConn[tmp].second = conn; + qpNumToConn[conn->ibQp->qp->qp_num] = conn; + // All connections may read + if (conn->ibQp->postRecv(0) != 0) { WARN("postRecv failed: errno %d", errno); } } // TODO(chhwang): run send and recv in different threads for lower latency + mscclppTrigger trigger; int wcNum; while (*stop == 0) { // Try send - for (auto &pair : sendConnToState) { - if (pair.second == SEND_STATE_INPROGRESS) continue; - // TODO(chhwang): do we need a thread per flag? - struct mscclppConn *conn = pair.first; - volatile int *flag = (volatile int *)conn->flag; - if (*flag == 0) continue; + // TODO(chhwang): one thread per conn + for (auto &pair : trigToSendStateAndConn) { + if (pair.second.first != SEND_STATE_INIT) continue; + trigger.value = *pair.first; + if (trigger.value == 0) continue; // Do send - conn->ibQp->stageSend(conn->ibMr, &conn->ibRemoteMrInfo, conn->buffSize, - (uint64_t)conn->tag, (unsigned int)conn->tag); + struct mscclppConn *conn = pair.second.second; +#if (MSCCLPP_PROXY_FLAG_SET_BY_RDMA == 1) + conn->ibQp->stageSend(conn->ibBuffMr, &conn->ibBuffMrInfo, (uint32_t)trigger.fields.dataSize, + /*wrId=*/0, /*immData=*/0, /*offset=*/trigger.fields.dataOffset, /*signaled=*/false); + // My local flag is copied to the peer's remote flag + conn->ibQp->stageSend(conn->ibLocalFlagMr, &conn->ibRemoteFlagMrInfo, sizeof(int), + /*wrId=*/0, /*immData=*/0, /*offset=*/0, /*signaled=*/true); +#else + conn->ibQp->stageSend(conn->ibBuffMr, &conn->ibBuffMrInfo, (uint32_t)trigger.fields.dataSize, + /*wrId=*/0, /*immData=*/0, /*offset=*/trigger.fields.dataOffset, /*signaled=*/true); +#endif if (conn->ibQp->postSend() != 0) { WARN("postSend failed: errno %d", errno); } - pair.second = SEND_STATE_INPROGRESS; + pair.second.first = SEND_STATE_INPROGRESS; } // Poll completions @@ -74,32 +78,26 @@ void* mscclppProxyService(void* _args) { for (int i = 0; i < wcNum; ++i) { struct ibv_wc *wc = &ibCtx->wcs[i]; if (wc->status != IBV_WC_SUCCESS) { - WARN("wc status %d", wc->status); + WARN("rank %d wc status %d", rank, wc->status); + continue; } - if (((int)wc->wr_id) < 0) { - // recv - auto search = recvTagToConn.find(wc->imm_data); - if (search == recvTagToConn.end()) { - WARN("unexpected imm_data %d", wc->imm_data); - } - struct mscclppConn *conn = search->second; - if (conn->ibQp->postRecv((uint64_t)-wc->imm_data) != 0) { + struct mscclppConn *conn = qpNumToConn[wc->qp_num]; + if (wc->opcode == IBV_WC_RECV_RDMA_WITH_IMM) { + // recv completion + if (qpNumToConn[wc->qp_num]->ibQp->postRecv(wc->wr_id) != 0) { WARN("postRecv failed: errno %d", errno); } - volatile int *flag = (volatile int *)conn->flag; - *flag = 1; - } else { - // send - int tag = (int)wc->wr_id; - auto search = sendTagToConn.find(tag); - if (search == sendTagToConn.end()) { - WARN("unexpected tag %d", tag); - } - struct mscclppConn *conn = search->second; - volatile int *flag = (volatile int *)conn->flag; - *flag = 0; - sendConnToState[conn] = SEND_STATE_INIT; - // WARN("send done rank %d", rank); +#if (MSCCLPP_PROXY_FLAG_SET_BY_RDMA != 1) + // TODO(chhwang): gdc & cpu flush + // *((volatile int *)conn->devConn->remoteFlag) = 1; +#endif + // WARN("rank %d recv completion", rank); + } else if (wc->opcode == IBV_WC_RDMA_WRITE) { + // send completion + volatile uint64_t *tmp = (volatile uint64_t *)conn->devConn->trigger; + *tmp = 0; + trigToSendStateAndConn[tmp].first = SEND_STATE_INIT; + // WARN("rank %d send completion", rank); } } } diff --git a/src/include/comm.h b/src/include/comm.h index 1bbb9491..df5270f9 100644 --- a/src/include/comm.h +++ b/src/include/comm.h @@ -160,8 +160,7 @@ struct mscclppConn { mscclppTransport_t transport; - int rankSend; - int rankRecv; + int remoteRank; int tag; void* buff; int buffSize; @@ -169,8 +168,12 @@ struct mscclppConn { struct mscclppDevConn *devConn; struct mscclppIbContext *ibCtx; struct mscclppIbQp *ibQp; - struct mscclppIbMr *ibMr; - struct mscclppIbMrInfo ibRemoteMrInfo; + struct mscclppIbMr *ibBuffMr; + struct mscclppIbMr *ibLocalFlagMr; + struct mscclppIbMr *ibRemoteFlagMr; + struct mscclppIbMrInfo ibBuffMrInfo; + struct mscclppIbMrInfo ibLocalFlagMrInfo; + struct mscclppIbMrInfo ibRemoteFlagMrInfo; }; struct mscclppComm { diff --git a/src/include/ib.h b/src/include/ib.h index 77212fa3..04bb3252 100644 --- a/src/include/ib.h +++ b/src/include/ib.h @@ -45,8 +45,8 @@ struct mscclppIbQp { int rtr(const mscclppIbQpInfo *info); int rts(); - int stageSend(struct mscclppIbMr *ibMr, const mscclppIbMrInfo *info, int size, - uint64_t wrId, unsigned int immData, int offset = 0); + int stageSend(struct mscclppIbMr *ibMr, const mscclppIbMrInfo *info, uint32_t size, + uint64_t wrId, unsigned int immData, uint64_t offset, bool signaled); int postSend(); int postRecv(uint64_t wrId); }; diff --git a/src/include/mscclpp.h b/src/include/mscclpp.h index 3d338fdf..a7d2fbf8 100644 --- a/src/include/mscclpp.h +++ b/src/include/mscclpp.h @@ -6,6 +6,7 @@ #if CUDART_VERSION >= 11000 #include #endif +#include #define MSCCLPP_MAJOR 0 #define MSCCLPP_MINOR 1 @@ -16,6 +17,14 @@ extern "C" { #endif +union alignas(8) mscclppTrigger { + uint64_t value; + struct { + uint64_t dataSize : 32; + uint64_t dataOffset : 32; + } fields; +}; + struct mscclppDevConn { int tag; @@ -32,6 +41,8 @@ struct mscclppDevConn { // virtual void pullRmoteFlag(); // // localBuff[srcOffset..srcOffset+size-1] <- remoteBuff[dstOffset..dstOffset+size-1] // virtual void pullRemoteBuff(size_t srcOffset, size_t dstOffset, size_t size); + + mscclppTrigger* trigger; }; typedef struct mscclppComm* mscclppComm_t; @@ -102,8 +113,8 @@ mscclppResult_t mscclppBootStrapAllGather(mscclppComm_t comm, void* data, int si mscclppResult_t mscclppCommDestroy(mscclppComm_t comm); -mscclppResult_t mscclppConnect(mscclppComm_t comm, int rankRecv, int rankSend, void *buff, size_t buffSize, int *flag, int tag, - mscclppTransport_t transportType, const char *ibDev=NULL); +mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, void *buff, size_t buffSize, int *flag, + int tag, mscclppTransport_t transportType, const char *ibDev=NULL); mscclppResult_t mscclppConnectionSetup(mscclppComm_t comm); diff --git a/tests/p2p_test.cu b/tests/p2p_test.cu index e4b15e8c..80620a99 100644 --- a/tests/p2p_test.cu +++ b/tests/p2p_test.cu @@ -14,75 +14,51 @@ #define CUDACHECK(cmd) do { \ cudaError_t err = cmd; \ if( err != cudaSuccess ) { \ - printf("Cuda failure '%s'", cudaGetErrorString(err)); \ + printf("%s:%d Cuda failure '%s'", __FILE__, __LINE__, cudaGetErrorString(err)); \ exit(EXIT_FAILURE); \ } \ } while(false) __global__ void kernel(mscclppDevConn_t devConns, int rank, int world_size) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid == 0) { - // Get sending data and send flag - volatile int *data; - for (int i = 0; i < (world_size - 1) * 2; ++i) { - mscclppDevConn_t devConn = &devConns[i]; - int tag = devConn->tag; - int rankSend = tag % world_size; - if (rankSend == rank) { // I am a sender - data = (volatile int *)devConn->localBuff; - // We are sending the same data to all peers, so just break here - break; - } - } + int warpId = threadIdx.x / 32; + int remoteRank = (warpId < rank) ? warpId : warpId + 1; + mscclppDevConn_t devConn = &devConns[(remoteRank < rank) ? remoteRank : remoteRank - 1]; + volatile int *data = (volatile int *)devConn->localBuff; + volatile int *localFlag = devConn->localFlag; + volatile int *remoteFlag = devConn->remoteFlag; + volatile uint64_t *trig = (volatile uint64_t *)devConn->trigger; - // Set my data - *data = rank + 1; + if (threadIdx.x == 0) { + // Set my data and flag + *(data + rank) = rank + 1; + __threadfence_system(); + *localFlag = 1; + } + __syncthreads(); - // Set send flags to inform all peers that the data is ready - for (int i = 0; i < (world_size - 1) * 2; ++i) { - mscclppDevConn_t devConn = &devConns[i]; - int tag = devConn->tag; - int rankSend = tag % world_size; - if (rankSend == rank) { // I am a sender - *((volatile int *)devConn->localFlag) = 1; - } - } + // Each warp receives data from different ranks + if (threadIdx.x % 32 == 0) { + if (devConn->remoteBuff == NULL) { // IB + // Trigger sending data and flag + uint64_t dataOffset = rank * sizeof(int); + uint64_t dataSize = sizeof(int); + *trig = (dataOffset << 32) + dataSize; - // Read data from all other peers - for (int i = 0; i < (world_size - 1) * 2; ++i) { - mscclppDevConn_t devConn = &devConns[i]; - int tag = devConn->tag; - int rankSend = tag % world_size; - int rankRecv = tag / world_size; - if (rankRecv == rank) { // I am a receiver - if (devConn->remoteBuff == NULL) { // IB - volatile int *localFlag = (volatile int *)devConn->localFlag; + // Wait until the proxy have sent my data and flag + while (*trig != 0) {} - // Wait until the data comes in via proxy - while (*localFlag != 1) {} - } else { // P2P - volatile int *remoteData = (volatile int *)devConn->remoteBuff; - volatile int *remoteFlag = (volatile int *)devConn->remoteFlag; + // Wait for receiving data from remote rank + while (*remoteFlag != 1) {} + } else { // P2P + // Directly read data + volatile int *remoteData = (volatile int *)devConn->remoteBuff; - // Wait until the remote data is set - while (*remoteFlag != 1) {} + // Wait until the remote data is set + while (*remoteFlag != 1) {} - // Read remote data - data[rankSend] = remoteData[rankSend]; - } - } - } - - // Wait until the proxy have sent my data to all peers - for (int i = 0; i < (world_size - 1) * 2; ++i) { - mscclppDevConn_t devConn = &devConns[i]; - int tag = devConn->tag; - int rankSend = tag % world_size; - if (rankSend == rank) { // I am a sender - volatile int *flag = (volatile int *)devConn->localFlag; - while (*flag == 1) {} - } + // Read remote data + data[remoteRank] = remoteData[remoteRank]; } } } @@ -133,6 +109,8 @@ int main(int argc, const char *argv[]) int rank = atoi(argv[2]); int world_size = atoi(argv[3]); #endif + int localRank = rankToLocalRank(rank); + int thisNode = rankToNode(rank); mscclppComm_t comm; mscclppResult_t res = mscclppCommInitRank(&comm, world_size, rank, ip_port); @@ -141,64 +119,33 @@ int main(int argc, const char *argv[]) return -1; } + CUDACHECK(cudaSetDevice(localRank)); + int *data_d; - int *send_flags_d; - int *recv_flags_d; + int *flag_d; CUDACHECK(cudaMalloc(&data_d, sizeof(int) * world_size)); - CUDACHECK(cudaHostAlloc(&send_flags_d, sizeof(int) * (world_size - 1), cudaHostAllocMapped)); - CUDACHECK(cudaHostAlloc(&recv_flags_d, sizeof(int) * (world_size - 1), cudaHostAllocMapped)); - + CUDACHECK(cudaMalloc(&flag_d, sizeof(int))); CUDACHECK(cudaMemset(data_d, 0, sizeof(int) * world_size)); - // CUDACHECK(cudaMemcpy(data_d, tmp, sizeof(int) * 2, cudaMemcpyHostToDevice)); - // printf("rank %d CPU: setting data at %p\n", rank, data_d + rank); - memset(send_flags_d, 0, sizeof(int) * (world_size - 1)); - memset(recv_flags_d, 0, sizeof(int) * (world_size - 1)); + CUDACHECK(cudaMemset(flag_d, 0, sizeof(int))); - int localRank = rankToLocalRank(rank); - int thisNode = rankToNode(rank); - std::string ibDev = "mlx5_ib" + std::to_string(localRank); + std::string ibDevStr = "mlx5_ib" + std::to_string(localRank); - // Read from all other ranks - int idx = 0; for (int r = 0; r < world_size; ++r) { if (r == rank) continue; - int tag = rank * world_size + r; + mscclppTransport_t transportType = mscclppTransportIB; + const char *ibDev = ibDevStr.c_str(); #if (TEST_CONN_TYPE == 0) // P2P+IB - int node = rankToNode(r); - if (node == thisNode) { - res = mscclppConnect(comm, rank, r, data_d + r, sizeof(int), recv_flags_d + idx, tag, mscclppTransportP2P); - } else { - res = mscclppConnect(comm, rank, r, data_d + r, sizeof(int), recv_flags_d + idx, tag, mscclppTransportIB, ibDev.c_str()); + if (rankToNode(r) == thisNode) { + transportType = mscclppTransportP2P; + ibDev = NULL; } -#else // (TEST_CONN_TYPE == 1) // IB-Only - res = mscclppConnect(comm, rank, r, data_d + r, sizeof(int), recv_flags_d + idx, tag, mscclppTransportIB, ibDev.c_str()); #endif + // Connect with all other ranks + res = mscclppConnect(comm, r, data_d, sizeof(int) * world_size, flag_d, 0, transportType, ibDev); if (res != mscclppSuccess) { printf("mscclppConnect failed\n"); return -1; } - ++idx; - } - // Let others read from me - idx = 0; - for (int r = 0; r < world_size; ++r) { - if (r == rank) continue; - int tag = r * world_size + rank; -#if (TEST_CONN_TYPE == 0) // P2P+IB - int node = rankToNode(r); - if (node == thisNode) { - res = mscclppConnect(comm, r, rank, data_d + rank, sizeof(int), send_flags_d + idx, tag, mscclppTransportP2P); - } else { - res = mscclppConnect(comm, r, rank, data_d + rank, sizeof(int), send_flags_d + idx, tag, mscclppTransportIB, ibDev.c_str()); - } -#else // (TEST_CONN_TYPE == 1) // IB-Only - res = mscclppConnect(comm, r, rank, data_d + rank, sizeof(int), send_flags_d + idx, tag, mscclppTransportIB, ibDev.c_str()); -#endif - if (res != mscclppSuccess) { - printf("mscclppConnect failed\n"); - return -1; - } - ++idx; } res = mscclppConnectionSetup(comm); @@ -216,7 +163,7 @@ int main(int argc, const char *argv[]) mscclppDevConn_t devConns; mscclppGetDevConns(comm, &devConns); - kernel<<<1, 1>>>(devConns, rank, world_size); + kernel<<<1, 32 * (world_size - 1)>>>(devConns, rank, world_size); CUDACHECK(cudaDeviceSynchronize()); res = mscclppProxyStop(comm);