Bidirectional connection

This commit is contained in:
Changho Hwang
2023-02-22 06:06:14 +00:00
parent 33e20aceb9
commit 91e04a527b
7 changed files with 191 additions and 225 deletions

View File

@@ -298,8 +298,8 @@ int mscclppIbQp::rts()
IBV_QP_MAX_QP_RD_ATOMIC);
}
int mscclppIbQp::stageSend(struct mscclppIbMr *ibMr, const mscclppIbMrInfo *info, int size,
uint64_t wrId, unsigned int immData, int offset)
int mscclppIbQp::stageSend(struct mscclppIbMr *ibMr, const mscclppIbMrInfo *info, uint32_t size,
uint64_t wrId, unsigned int immData, uint64_t offset, bool signaled)
{
if (this->wrn >= MSCCLPP_IB_MAX_SENDS) {
return -1;
@@ -314,11 +314,11 @@ int mscclppIbQp::stageSend(struct mscclppIbMr *ibMr, const mscclppIbMrInfo *info
wr_->num_sge = 1;
wr_->opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
wr_->imm_data = immData;
wr_->send_flags = IBV_SEND_SIGNALED;
wr_->wr.rdma.remote_addr = info->addr;
wr_->send_flags = signaled ? IBV_SEND_SIGNALED : 0;
wr_->wr.rdma.remote_addr = (uint64_t)(info->addr) + offset;
wr_->wr.rdma.rkey = info->rkey;
wr_->next = nullptr;
sge_->addr = (uint64_t)(ibMr->buff) + (uint64_t)offset;
sge_->addr = (uint64_t)(ibMr->buff) + offset;
sge_->length = size;
sge_->lkey = ibMr->mr->lkey;
if (wrn > 0) {

View File

@@ -141,105 +141,112 @@ mscclppResult_t mscclppCommDestroy(mscclppComm_t comm){
return mscclppSuccess;
}
MSCCLPP_API(mscclppResult_t, mscclppConnect, mscclppComm_t comm, int rankRecv, int rankSend,
MSCCLPP_API(mscclppResult_t, mscclppConnect, mscclppComm_t comm, int remoteRank,
void *buff, size_t buffSize, int *flag, int tag, mscclppTransport_t transportType, const char *ibDev);
mscclppResult_t mscclppConnect(mscclppComm_t comm, int rankRecv, int rankSend, void *buff, size_t buffSize,
mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, void *buff, size_t buffSize,
int *flag, int tag, mscclppTransport_t transportType, const char *ibDev/*=NULL*/)
{
if (comm->rank == rankRecv || comm->rank == rankSend) {
struct mscclppConn *conn = &comm->conns[comm->nConns++];
conn->transport = transportType;
conn->rankSend = rankSend;
conn->rankRecv = rankRecv;
conn->tag = tag;
conn->buff = buff;
conn->buffSize = buffSize;
conn->flag = flag;
conn->ibCtx = NULL;
conn->ibQp = NULL;
struct mscclppConn *conn = &comm->conns[comm->nConns++];
conn->transport = transportType;
conn->remoteRank = remoteRank;
conn->tag = tag;
conn->buff = buff;
conn->buffSize = buffSize;
conn->flag = flag;
conn->ibCtx = NULL;
conn->ibQp = NULL;
if (ibDev != NULL) {
// Check if an IB context exists
int ibDevIdx = -1;
int firstNullIdx = -1;
for (int i = 0; i < MSCCLPP_IB_MAX_DEVS; ++i) {
if (comm->ibContext[i] == NULL) {
if (firstNullIdx == -1) {
firstNullIdx = i;
}
} else if (strncmp(comm->ibContext[i]->ctx->device->name, ibDev, IBV_SYSFS_NAME_MAX) == 0) {
ibDevIdx = i;
break;
}
}
if (ibDevIdx == -1) {
// Create a new context.
if (ibDev != NULL) {
// Check if an IB context exists
int ibDevIdx = -1;
int firstNullIdx = -1;
for (int i = 0; i < MSCCLPP_IB_MAX_DEVS; ++i) {
if (comm->ibContext[i] == NULL) {
if (firstNullIdx == -1) {
WARN("Too many IB devices");
return mscclppInvalidUsage;
}
ibDevIdx = firstNullIdx;
if (mscclppIbContextCreate(&comm->ibContext[ibDevIdx], ibDev) != mscclppSuccess) {
WARN("Failed to create IB context");
return mscclppInternalError;
firstNullIdx = i;
}
} else if (strncmp(comm->ibContext[i]->ctx->device->name, ibDev, IBV_SYSFS_NAME_MAX) == 0) {
ibDevIdx = i;
break;
}
conn->ibCtx = comm->ibContext[ibDevIdx];
}
if (ibDevIdx == -1) {
// Create a new context.
if (firstNullIdx == -1) {
WARN("Too many IB devices");
return mscclppInvalidUsage;
}
ibDevIdx = firstNullIdx;
if (mscclppIbContextCreate(&comm->ibContext[ibDevIdx], ibDev) != mscclppSuccess) {
WARN("Failed to create IB context");
return mscclppInternalError;
}
}
conn->ibCtx = comm->ibContext[ibDevIdx];
}
return mscclppSuccess;
}
// Per-connection setup record exchanged with the remote rank of a connection.
// mscclppConnectionSetup fills one in for each connection, sends it as raw
// bytes with bootstrapSend, and receives the peer's copy with bootstrapRecv.
// Only the members that belong to the connection's transport are assigned
// before sending (the CUDA IPC handles for mscclppTransportP2P, the IB infos
// for mscclppTransportIB); the sender declares the record without
// initializing it, so the remaining members carry no meaningful value.
struct connInfo {
// P2P: CUDA IPC handle of the sender's local data buffer; the receiver opens
// it to obtain devConn->remoteBuff.
cudaIpcMemHandle_t handleBuff;
// P2P: CUDA IPC handle of the sender's local flag; the receiver opens it to
// obtain devConn->remoteFlag.
cudaIpcMemHandle_t handleFlag;
// IB: the sender's queue pair info (conn->ibQp->info); the receiver passes it
// to ibQp->rtr().
mscclppIbQpInfo infoQp;
// IB: info of the memory region registered over the sender's local data
// buffer; the receiver keeps it as conn->ibBuffMrInfo and the proxy uses it
// as the remote target of the data RDMA write.
mscclppIbMrInfo infoBuffMr;
// IB: info of the memory region registered over the sender's local flag; the
// receiver keeps it as conn->ibLocalFlagMrInfo.
mscclppIbMrInfo infoLocalFlagMr;
// IB: info of the memory region registered over the sender's remote-flag
// slot; the receiver keeps it as conn->ibRemoteFlagMrInfo and the proxy uses
// it as the remote target of the flag RDMA write.
mscclppIbMrInfo infoRemoteFlagMr;
};
MSCCLPP_API(mscclppResult_t, mscclppConnectionSetup, mscclppComm_t comm);
mscclppResult_t mscclppConnectionSetup(mscclppComm_t comm)
{
struct connInfo {
cudaIpcMemHandle_t handleBuff;
cudaIpcMemHandle_t handleFlag;
mscclppIbQpInfo qpInfo;
mscclppIbMrInfo mrInfo;
};
// Allocate connection info to be shared with GPU
MSCCLPPCHECK(mscclppCudaHostCalloc(&comm->devConns, comm->nConns));
// Send info to peers
for (int i = 0; i < comm->nConns; ++i) {
struct mscclppConn *conn = &comm->conns[i];
struct mscclppDevConn *devConn = &comm->devConns[i];
conn->devConn = devConn;
devConn->tag = conn->tag;
devConn->localBuff = conn->buff;
devConn->localFlag = conn->flag;
MSCCLPPCHECK(mscclppCudaHostCalloc(&devConn->trigger, 1));
struct connInfo cInfo;
if (conn->transport == mscclppTransportP2P) {
CUDACHECK(cudaIpcGetMemHandle(&cInfo.handleBuff, conn->buff));
CUDACHECK(cudaIpcGetMemHandle(&cInfo.handleFlag, conn->flag));
CUDACHECK(cudaIpcGetMemHandle(&cInfo.handleBuff, devConn->localBuff));
CUDACHECK(cudaIpcGetMemHandle(&cInfo.handleFlag, devConn->localFlag));
} else if (conn->transport == mscclppTransportIB) {
devConn->remoteBuff = NULL;
MSCCLPPCHECK(mscclppCudaCalloc(&devConn->remoteFlag, 1));
struct mscclppIbContext *ibCtx = conn->ibCtx;
if (conn->ibQp == NULL) {
MSCCLPPCHECK(mscclppIbContextCreateQp(ibCtx, &conn->ibQp));
}
MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, conn->buff, conn->buffSize, &conn->ibMr));
cInfo.qpInfo = conn->ibQp->info;
cInfo.mrInfo = conn->ibMr->info;
MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->localBuff, conn->buffSize, &conn->ibBuffMr));
MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->localFlag, sizeof(int), &conn->ibLocalFlagMr));
MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->remoteFlag, sizeof(int), &conn->ibRemoteFlagMr));
cInfo.infoQp = conn->ibQp->info;
cInfo.infoBuffMr = conn->ibBuffMr->info;
cInfo.infoLocalFlagMr = conn->ibLocalFlagMr->info;
cInfo.infoRemoteFlagMr = conn->ibRemoteFlagMr->info;
}
int peer = conn->rankSend == comm->rank ? conn->rankRecv : conn->rankSend;
MSCCLPPCHECK(bootstrapSend(comm->bootstrap, peer, conn->tag, &cInfo, sizeof(cInfo)));
MSCCLPPCHECK(bootstrapSend(comm->bootstrap, conn->remoteRank, conn->tag, &cInfo, sizeof(cInfo)));
}
// Allocate connection info to be shared with GPU
MSCCLPPCHECK(mscclppCudaHostCalloc(&comm->devConns, comm->nConns));
// Recv info from peers
for (int i = 0; i < comm->nConns; ++i) {
struct mscclppConn *conn = &comm->conns[i];
struct mscclppDevConn *devConn = &comm->devConns[i];
devConn->tag = conn->tag;
devConn->localBuff = conn->buff;
devConn->localFlag = conn->flag;
struct connInfo cInfo;
int peer = conn->rankSend == comm->rank ? conn->rankRecv : conn->rankSend;
MSCCLPPCHECK(bootstrapRecv(comm->bootstrap, peer, conn->tag, &cInfo, sizeof(cInfo)));
MSCCLPPCHECK(bootstrapRecv(comm->bootstrap, conn->remoteRank, conn->tag, &cInfo, sizeof(cInfo)));
if (conn->transport == mscclppTransportP2P) {
CUDACHECK(cudaIpcOpenMemHandle(&devConn->remoteBuff, cInfo.handleBuff, cudaIpcMemLazyEnablePeerAccess));
CUDACHECK(cudaIpcOpenMemHandle((void **)&devConn->remoteFlag, cInfo.handleFlag, cudaIpcMemLazyEnablePeerAccess));
} else if (conn->transport == mscclppTransportIB) {
if (conn->ibQp->rtr(&cInfo.qpInfo) != 0) {
if (conn->ibQp->rtr(&cInfo.infoQp) != 0) {
WARN("Failed to transition QP to RTR");
return mscclppInvalidUsage;
}
@@ -247,9 +254,9 @@ mscclppResult_t mscclppConnectionSetup(mscclppComm_t comm)
WARN("Failed to transition QP to RTS");
return mscclppInvalidUsage;
}
conn->ibRemoteMrInfo = cInfo.mrInfo;
devConn->remoteBuff = NULL;
CUDACHECK(cudaMalloc(&devConn->remoteFlag, sizeof(int)));
conn->ibBuffMrInfo = cInfo.infoBuffMr;
conn->ibLocalFlagMrInfo = cInfo.infoLocalFlagMr;
conn->ibRemoteFlagMrInfo = cInfo.infoRemoteFlagMr;
}
}

View File

@@ -8,6 +8,8 @@
#include <sys/syscall.h>
#include <map>
#define MSCCLPP_PROXY_FLAG_SET_BY_RDMA 1
struct proxyArgs {
struct mscclppComm* comm;
struct mscclppIbContext* ibCtx;
@@ -27,45 +29,47 @@ void* mscclppProxyService(void* _args) {
};
int rank = comm->rank;
std::map<int, struct mscclppConn *> recvTagToConn;
std::map<int, struct mscclppConn *> sendTagToConn;
std::map<struct mscclppConn *, int> sendConnToState;
std::map<uint32_t, struct mscclppConn *> qpNumToConn;
std::map<volatile uint64_t *, std::pair<int, struct mscclppConn *>> trigToSendStateAndConn;
for (int i = 0; i < comm->nConns; ++i) {
struct mscclppConn *conn = &comm->conns[i];
if (conn->transport != mscclppTransportIB) continue;
if (conn->ibCtx != ibCtx) continue;
if (conn->rankRecv == rank) {
recvTagToConn[conn->tag] = conn;
} else if (conn->rankSend == rank) {
sendTagToConn[conn->tag] = conn;
sendConnToState[conn] = SEND_STATE_INIT;
}
}
// Initial post recv
for (auto &pair : recvTagToConn) {
struct mscclppConn *conn = pair.second;
int tag = pair.first;
if (conn->ibQp->postRecv((uint64_t)-tag) != 0) {
volatile uint64_t *tmp = (volatile uint64_t *)conn->devConn->trigger;
trigToSendStateAndConn[tmp].first = SEND_STATE_INIT;
trigToSendStateAndConn[tmp].second = conn;
qpNumToConn[conn->ibQp->qp->qp_num] = conn;
// All connections may read
if (conn->ibQp->postRecv(0) != 0) {
WARN("postRecv failed: errno %d", errno);
}
}
// TODO(chhwang): run send and recv in different threads for lower latency
mscclppTrigger trigger;
int wcNum;
while (*stop == 0) {
// Try send
for (auto &pair : sendConnToState) {
if (pair.second == SEND_STATE_INPROGRESS) continue;
// TODO(chhwang): do we need a thread per flag?
struct mscclppConn *conn = pair.first;
volatile int *flag = (volatile int *)conn->flag;
if (*flag == 0) continue;
// TODO(chhwang): one thread per conn
for (auto &pair : trigToSendStateAndConn) {
if (pair.second.first != SEND_STATE_INIT) continue;
trigger.value = *pair.first;
if (trigger.value == 0) continue;
// Do send
conn->ibQp->stageSend(conn->ibMr, &conn->ibRemoteMrInfo, conn->buffSize,
(uint64_t)conn->tag, (unsigned int)conn->tag);
struct mscclppConn *conn = pair.second.second;
#if (MSCCLPP_PROXY_FLAG_SET_BY_RDMA == 1)
conn->ibQp->stageSend(conn->ibBuffMr, &conn->ibBuffMrInfo, (uint32_t)trigger.fields.dataSize,
/*wrId=*/0, /*immData=*/0, /*offset=*/trigger.fields.dataOffset, /*signaled=*/false);
// My local flag is copied to the peer's remote flag
conn->ibQp->stageSend(conn->ibLocalFlagMr, &conn->ibRemoteFlagMrInfo, sizeof(int),
/*wrId=*/0, /*immData=*/0, /*offset=*/0, /*signaled=*/true);
#else
conn->ibQp->stageSend(conn->ibBuffMr, &conn->ibBuffMrInfo, (uint32_t)trigger.fields.dataSize,
/*wrId=*/0, /*immData=*/0, /*offset=*/trigger.fields.dataOffset, /*signaled=*/true);
#endif
if (conn->ibQp->postSend() != 0) {
WARN("postSend failed: errno %d", errno);
}
pair.second = SEND_STATE_INPROGRESS;
pair.second.first = SEND_STATE_INPROGRESS;
}
// Poll completions
@@ -74,32 +78,26 @@ void* mscclppProxyService(void* _args) {
for (int i = 0; i < wcNum; ++i) {
struct ibv_wc *wc = &ibCtx->wcs[i];
if (wc->status != IBV_WC_SUCCESS) {
WARN("wc status %d", wc->status);
WARN("rank %d wc status %d", rank, wc->status);
continue;
}
if (((int)wc->wr_id) < 0) {
// recv
auto search = recvTagToConn.find(wc->imm_data);
if (search == recvTagToConn.end()) {
WARN("unexpected imm_data %d", wc->imm_data);
}
struct mscclppConn *conn = search->second;
if (conn->ibQp->postRecv((uint64_t)-wc->imm_data) != 0) {
struct mscclppConn *conn = qpNumToConn[wc->qp_num];
if (wc->opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
// recv completion
if (qpNumToConn[wc->qp_num]->ibQp->postRecv(wc->wr_id) != 0) {
WARN("postRecv failed: errno %d", errno);
}
volatile int *flag = (volatile int *)conn->flag;
*flag = 1;
} else {
// send
int tag = (int)wc->wr_id;
auto search = sendTagToConn.find(tag);
if (search == sendTagToConn.end()) {
WARN("unexpected tag %d", tag);
}
struct mscclppConn *conn = search->second;
volatile int *flag = (volatile int *)conn->flag;
*flag = 0;
sendConnToState[conn] = SEND_STATE_INIT;
// WARN("send done rank %d", rank);
#if (MSCCLPP_PROXY_FLAG_SET_BY_RDMA != 1)
// TODO(chhwang): gdc & cpu flush
// *((volatile int *)conn->devConn->remoteFlag) = 1;
#endif
// WARN("rank %d recv completion", rank);
} else if (wc->opcode == IBV_WC_RDMA_WRITE) {
// send completion
volatile uint64_t *tmp = (volatile uint64_t *)conn->devConn->trigger;
*tmp = 0;
trigToSendStateAndConn[tmp].first = SEND_STATE_INIT;
// WARN("rank %d send completion", rank);
}
}
}

View File

@@ -160,8 +160,7 @@
struct mscclppConn {
mscclppTransport_t transport;
int rankSend;
int rankRecv;
int remoteRank;
int tag;
void* buff;
int buffSize;
@@ -169,8 +168,12 @@ struct mscclppConn {
struct mscclppDevConn *devConn;
struct mscclppIbContext *ibCtx;
struct mscclppIbQp *ibQp;
struct mscclppIbMr *ibMr;
struct mscclppIbMrInfo ibRemoteMrInfo;
struct mscclppIbMr *ibBuffMr;
struct mscclppIbMr *ibLocalFlagMr;
struct mscclppIbMr *ibRemoteFlagMr;
struct mscclppIbMrInfo ibBuffMrInfo;
struct mscclppIbMrInfo ibLocalFlagMrInfo;
struct mscclppIbMrInfo ibRemoteFlagMrInfo;
};
struct mscclppComm {

View File

@@ -45,8 +45,8 @@ struct mscclppIbQp {
int rtr(const mscclppIbQpInfo *info);
int rts();
int stageSend(struct mscclppIbMr *ibMr, const mscclppIbMrInfo *info, int size,
uint64_t wrId, unsigned int immData, int offset = 0);
int stageSend(struct mscclppIbMr *ibMr, const mscclppIbMrInfo *info, uint32_t size,
uint64_t wrId, unsigned int immData, uint64_t offset, bool signaled);
int postSend();
int postRecv(uint64_t wrId);
};

View File

@@ -6,6 +6,7 @@
#if CUDART_VERSION >= 11000
#include <cuda_bf16.h>
#endif
#include <stdint.h>
#define MSCCLPP_MAJOR 0
#define MSCCLPP_MINOR 1
@@ -16,6 +17,14 @@
extern "C" {
#endif
// 64-bit send request shared between a GPU kernel and the host proxy through
// mscclppDevConn::trigger. The proxy reads `value`, treats zero as "no
// request", stages an RDMA write described by `fields` for any other value,
// and writes zero back once the send completion has been polled.
union alignas(8) mscclppTrigger {
// Whole request viewed as a single word, so it can be read and written with
// one 64-bit access.
uint64_t value;
struct {
// Number of bytes to send; the proxy passes it as the `size` argument of
// stageSend.
uint64_t dataSize : 32;
// Byte offset passed as the `offset` argument of stageSend, which adds it to
// both the local buffer address and the remote address.
uint64_t dataOffset : 32;
} fields;
// NOTE(review): the test kernel builds `value` as (dataOffset << 32) + dataSize.
// That matches `fields` only if the compiler allocates bit-fields starting
// from the least-significant bit, which is implementation-defined — TODO
// confirm for every supported host compiler/ABI.
};
struct mscclppDevConn {
int tag;
@@ -32,6 +41,8 @@ struct mscclppDevConn {
// virtual void pullRemoteFlag();
// // localBuff[srcOffset..srcOffset+size-1] <- remoteBuff[dstOffset..dstOffset+size-1]
// virtual void pullRemoteBuff(size_t srcOffset, size_t dstOffset, size_t size);
mscclppTrigger* trigger;
};
typedef struct mscclppComm* mscclppComm_t;
@@ -102,8 +113,8 @@ mscclppResult_t mscclppBootStrapAllGather(mscclppComm_t comm, void* data, int si
mscclppResult_t mscclppCommDestroy(mscclppComm_t comm);
mscclppResult_t mscclppConnect(mscclppComm_t comm, int rankRecv, int rankSend, void *buff, size_t buffSize, int *flag, int tag,
mscclppTransport_t transportType, const char *ibDev=NULL);
mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, void *buff, size_t buffSize, int *flag,
int tag, mscclppTransport_t transportType, const char *ibDev=NULL);
mscclppResult_t mscclppConnectionSetup(mscclppComm_t comm);

View File

@@ -14,75 +14,51 @@
#define CUDACHECK(cmd) do { \
cudaError_t err = cmd; \
if( err != cudaSuccess ) { \
printf("Cuda failure '%s'", cudaGetErrorString(err)); \
printf("%s:%d Cuda failure '%s'", __FILE__, __LINE__, cudaGetErrorString(err)); \
exit(EXIT_FAILURE); \
} \
} while(false)
__global__ void kernel(mscclppDevConn_t devConns, int rank, int world_size)
{
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid == 0) {
// Get sending data and send flag
volatile int *data;
for (int i = 0; i < (world_size - 1) * 2; ++i) {
mscclppDevConn_t devConn = &devConns[i];
int tag = devConn->tag;
int rankSend = tag % world_size;
if (rankSend == rank) { // I am a sender
data = (volatile int *)devConn->localBuff;
// We are sending the same data to all peers, so just break here
break;
}
}
int warpId = threadIdx.x / 32;
int remoteRank = (warpId < rank) ? warpId : warpId + 1;
mscclppDevConn_t devConn = &devConns[(remoteRank < rank) ? remoteRank : remoteRank - 1];
volatile int *data = (volatile int *)devConn->localBuff;
volatile int *localFlag = devConn->localFlag;
volatile int *remoteFlag = devConn->remoteFlag;
volatile uint64_t *trig = (volatile uint64_t *)devConn->trigger;
// Set my data
*data = rank + 1;
if (threadIdx.x == 0) {
// Set my data and flag
*(data + rank) = rank + 1;
__threadfence_system();
*localFlag = 1;
}
__syncthreads();
// Set send flags to inform all peers that the data is ready
for (int i = 0; i < (world_size - 1) * 2; ++i) {
mscclppDevConn_t devConn = &devConns[i];
int tag = devConn->tag;
int rankSend = tag % world_size;
if (rankSend == rank) { // I am a sender
*((volatile int *)devConn->localFlag) = 1;
}
}
// Each warp receives data from different ranks
if (threadIdx.x % 32 == 0) {
if (devConn->remoteBuff == NULL) { // IB
// Trigger sending data and flag
uint64_t dataOffset = rank * sizeof(int);
uint64_t dataSize = sizeof(int);
*trig = (dataOffset << 32) + dataSize;
// Read data from all other peers
for (int i = 0; i < (world_size - 1) * 2; ++i) {
mscclppDevConn_t devConn = &devConns[i];
int tag = devConn->tag;
int rankSend = tag % world_size;
int rankRecv = tag / world_size;
if (rankRecv == rank) { // I am a receiver
if (devConn->remoteBuff == NULL) { // IB
volatile int *localFlag = (volatile int *)devConn->localFlag;
// Wait until the proxy has sent my data and flag
while (*trig != 0) {}
// Wait until the data comes in via proxy
while (*localFlag != 1) {}
} else { // P2P
volatile int *remoteData = (volatile int *)devConn->remoteBuff;
volatile int *remoteFlag = (volatile int *)devConn->remoteFlag;
// Wait for receiving data from remote rank
while (*remoteFlag != 1) {}
} else { // P2P
// Directly read data
volatile int *remoteData = (volatile int *)devConn->remoteBuff;
// Wait until the remote data is set
while (*remoteFlag != 1) {}
// Wait until the remote data is set
while (*remoteFlag != 1) {}
// Read remote data
data[rankSend] = remoteData[rankSend];
}
}
}
// Wait until the proxy has sent my data to all peers
for (int i = 0; i < (world_size - 1) * 2; ++i) {
mscclppDevConn_t devConn = &devConns[i];
int tag = devConn->tag;
int rankSend = tag % world_size;
if (rankSend == rank) { // I am a sender
volatile int *flag = (volatile int *)devConn->localFlag;
while (*flag == 1) {}
}
// Read remote data
data[remoteRank] = remoteData[remoteRank];
}
}
}
@@ -133,6 +109,8 @@ int main(int argc, const char *argv[])
int rank = atoi(argv[2]);
int world_size = atoi(argv[3]);
#endif
int localRank = rankToLocalRank(rank);
int thisNode = rankToNode(rank);
mscclppComm_t comm;
mscclppResult_t res = mscclppCommInitRank(&comm, world_size, rank, ip_port);
@@ -141,64 +119,33 @@ int main(int argc, const char *argv[])
return -1;
}
CUDACHECK(cudaSetDevice(localRank));
int *data_d;
int *send_flags_d;
int *recv_flags_d;
int *flag_d;
CUDACHECK(cudaMalloc(&data_d, sizeof(int) * world_size));
CUDACHECK(cudaHostAlloc(&send_flags_d, sizeof(int) * (world_size - 1), cudaHostAllocMapped));
CUDACHECK(cudaHostAlloc(&recv_flags_d, sizeof(int) * (world_size - 1), cudaHostAllocMapped));
CUDACHECK(cudaMalloc(&flag_d, sizeof(int)));
CUDACHECK(cudaMemset(data_d, 0, sizeof(int) * world_size));
// CUDACHECK(cudaMemcpy(data_d, tmp, sizeof(int) * 2, cudaMemcpyHostToDevice));
// printf("rank %d CPU: setting data at %p\n", rank, data_d + rank);
memset(send_flags_d, 0, sizeof(int) * (world_size - 1));
memset(recv_flags_d, 0, sizeof(int) * (world_size - 1));
CUDACHECK(cudaMemset(flag_d, 0, sizeof(int)));
int localRank = rankToLocalRank(rank);
int thisNode = rankToNode(rank);
std::string ibDev = "mlx5_ib" + std::to_string(localRank);
std::string ibDevStr = "mlx5_ib" + std::to_string(localRank);
// Read from all other ranks
int idx = 0;
for (int r = 0; r < world_size; ++r) {
if (r == rank) continue;
int tag = rank * world_size + r;
mscclppTransport_t transportType = mscclppTransportIB;
const char *ibDev = ibDevStr.c_str();
#if (TEST_CONN_TYPE == 0) // P2P+IB
int node = rankToNode(r);
if (node == thisNode) {
res = mscclppConnect(comm, rank, r, data_d + r, sizeof(int), recv_flags_d + idx, tag, mscclppTransportP2P);
} else {
res = mscclppConnect(comm, rank, r, data_d + r, sizeof(int), recv_flags_d + idx, tag, mscclppTransportIB, ibDev.c_str());
if (rankToNode(r) == thisNode) {
transportType = mscclppTransportP2P;
ibDev = NULL;
}
#else // (TEST_CONN_TYPE == 1) // IB-Only
res = mscclppConnect(comm, rank, r, data_d + r, sizeof(int), recv_flags_d + idx, tag, mscclppTransportIB, ibDev.c_str());
#endif
// Connect with all other ranks
res = mscclppConnect(comm, r, data_d, sizeof(int) * world_size, flag_d, 0, transportType, ibDev);
if (res != mscclppSuccess) {
printf("mscclppConnect failed\n");
return -1;
}
++idx;
}
// Let others read from me
idx = 0;
for (int r = 0; r < world_size; ++r) {
if (r == rank) continue;
int tag = r * world_size + rank;
#if (TEST_CONN_TYPE == 0) // P2P+IB
int node = rankToNode(r);
if (node == thisNode) {
res = mscclppConnect(comm, r, rank, data_d + rank, sizeof(int), send_flags_d + idx, tag, mscclppTransportP2P);
} else {
res = mscclppConnect(comm, r, rank, data_d + rank, sizeof(int), send_flags_d + idx, tag, mscclppTransportIB, ibDev.c_str());
}
#else // (TEST_CONN_TYPE == 1) // IB-Only
res = mscclppConnect(comm, r, rank, data_d + rank, sizeof(int), send_flags_d + idx, tag, mscclppTransportIB, ibDev.c_str());
#endif
if (res != mscclppSuccess) {
printf("mscclppConnect failed\n");
return -1;
}
++idx;
}
res = mscclppConnectionSetup(comm);
@@ -216,7 +163,7 @@ int main(int argc, const char *argv[])
mscclppDevConn_t devConns;
mscclppGetDevConns(comm, &devConns);
kernel<<<1, 1>>>(devConns, rank, world_size);
kernel<<<1, 32 * (world_size - 1)>>>(devConns, rank, world_size);
CUDACHECK(cudaDeviceSynchronize());
res = mscclppProxyStop(comm);