mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-23 06:16:46 +00:00
Bidirectional connection
This commit is contained in:
@@ -298,8 +298,8 @@ int mscclppIbQp::rts()
|
||||
IBV_QP_MAX_QP_RD_ATOMIC);
|
||||
}
|
||||
|
||||
int mscclppIbQp::stageSend(struct mscclppIbMr *ibMr, const mscclppIbMrInfo *info, int size,
|
||||
uint64_t wrId, unsigned int immData, int offset)
|
||||
int mscclppIbQp::stageSend(struct mscclppIbMr *ibMr, const mscclppIbMrInfo *info, uint32_t size,
|
||||
uint64_t wrId, unsigned int immData, uint64_t offset, bool signaled)
|
||||
{
|
||||
if (this->wrn >= MSCCLPP_IB_MAX_SENDS) {
|
||||
return -1;
|
||||
@@ -314,11 +314,11 @@ int mscclppIbQp::stageSend(struct mscclppIbMr *ibMr, const mscclppIbMrInfo *info
|
||||
wr_->num_sge = 1;
|
||||
wr_->opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
|
||||
wr_->imm_data = immData;
|
||||
wr_->send_flags = IBV_SEND_SIGNALED;
|
||||
wr_->wr.rdma.remote_addr = info->addr;
|
||||
wr_->send_flags = signaled ? IBV_SEND_SIGNALED : 0;
|
||||
wr_->wr.rdma.remote_addr = (uint64_t)(info->addr) + offset;
|
||||
wr_->wr.rdma.rkey = info->rkey;
|
||||
wr_->next = nullptr;
|
||||
sge_->addr = (uint64_t)(ibMr->buff) + (uint64_t)offset;
|
||||
sge_->addr = (uint64_t)(ibMr->buff) + offset;
|
||||
sge_->length = size;
|
||||
sge_->lkey = ibMr->mr->lkey;
|
||||
if (wrn > 0) {
|
||||
|
||||
@@ -141,105 +141,112 @@ mscclppResult_t mscclppCommDestroy(mscclppComm_t comm){
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
MSCCLPP_API(mscclppResult_t, mscclppConnect, mscclppComm_t comm, int rankRecv, int rankSend,
|
||||
MSCCLPP_API(mscclppResult_t, mscclppConnect, mscclppComm_t comm, int remoteRank,
|
||||
void *buff, size_t buffSize, int *flag, int tag, mscclppTransport_t transportType, const char *ibDev);
|
||||
mscclppResult_t mscclppConnect(mscclppComm_t comm, int rankRecv, int rankSend, void *buff, size_t buffSize,
|
||||
mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, void *buff, size_t buffSize,
|
||||
int *flag, int tag, mscclppTransport_t transportType, const char *ibDev/*=NULL*/)
|
||||
{
|
||||
if (comm->rank == rankRecv || comm->rank == rankSend) {
|
||||
struct mscclppConn *conn = &comm->conns[comm->nConns++];
|
||||
conn->transport = transportType;
|
||||
conn->rankSend = rankSend;
|
||||
conn->rankRecv = rankRecv;
|
||||
conn->tag = tag;
|
||||
conn->buff = buff;
|
||||
conn->buffSize = buffSize;
|
||||
conn->flag = flag;
|
||||
conn->ibCtx = NULL;
|
||||
conn->ibQp = NULL;
|
||||
struct mscclppConn *conn = &comm->conns[comm->nConns++];
|
||||
conn->transport = transportType;
|
||||
conn->remoteRank = remoteRank;
|
||||
conn->tag = tag;
|
||||
conn->buff = buff;
|
||||
conn->buffSize = buffSize;
|
||||
conn->flag = flag;
|
||||
conn->ibCtx = NULL;
|
||||
conn->ibQp = NULL;
|
||||
|
||||
if (ibDev != NULL) {
|
||||
// Check if an IB context exists
|
||||
int ibDevIdx = -1;
|
||||
int firstNullIdx = -1;
|
||||
for (int i = 0; i < MSCCLPP_IB_MAX_DEVS; ++i) {
|
||||
if (comm->ibContext[i] == NULL) {
|
||||
if (firstNullIdx == -1) {
|
||||
firstNullIdx = i;
|
||||
}
|
||||
} else if (strncmp(comm->ibContext[i]->ctx->device->name, ibDev, IBV_SYSFS_NAME_MAX) == 0) {
|
||||
ibDevIdx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (ibDevIdx == -1) {
|
||||
// Create a new context.
|
||||
if (ibDev != NULL) {
|
||||
// Check if an IB context exists
|
||||
int ibDevIdx = -1;
|
||||
int firstNullIdx = -1;
|
||||
for (int i = 0; i < MSCCLPP_IB_MAX_DEVS; ++i) {
|
||||
if (comm->ibContext[i] == NULL) {
|
||||
if (firstNullIdx == -1) {
|
||||
WARN("Too many IB devices");
|
||||
return mscclppInvalidUsage;
|
||||
}
|
||||
ibDevIdx = firstNullIdx;
|
||||
if (mscclppIbContextCreate(&comm->ibContext[ibDevIdx], ibDev) != mscclppSuccess) {
|
||||
WARN("Failed to create IB context");
|
||||
return mscclppInternalError;
|
||||
firstNullIdx = i;
|
||||
}
|
||||
} else if (strncmp(comm->ibContext[i]->ctx->device->name, ibDev, IBV_SYSFS_NAME_MAX) == 0) {
|
||||
ibDevIdx = i;
|
||||
break;
|
||||
}
|
||||
conn->ibCtx = comm->ibContext[ibDevIdx];
|
||||
}
|
||||
if (ibDevIdx == -1) {
|
||||
// Create a new context.
|
||||
if (firstNullIdx == -1) {
|
||||
WARN("Too many IB devices");
|
||||
return mscclppInvalidUsage;
|
||||
}
|
||||
ibDevIdx = firstNullIdx;
|
||||
if (mscclppIbContextCreate(&comm->ibContext[ibDevIdx], ibDev) != mscclppSuccess) {
|
||||
WARN("Failed to create IB context");
|
||||
return mscclppInternalError;
|
||||
}
|
||||
}
|
||||
conn->ibCtx = comm->ibContext[ibDevIdx];
|
||||
}
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
struct connInfo {
|
||||
cudaIpcMemHandle_t handleBuff;
|
||||
cudaIpcMemHandle_t handleFlag;
|
||||
mscclppIbQpInfo infoQp;
|
||||
mscclppIbMrInfo infoBuffMr;
|
||||
mscclppIbMrInfo infoLocalFlagMr;
|
||||
mscclppIbMrInfo infoRemoteFlagMr;
|
||||
};
|
||||
|
||||
MSCCLPP_API(mscclppResult_t, mscclppConnectionSetup, mscclppComm_t comm);
|
||||
mscclppResult_t mscclppConnectionSetup(mscclppComm_t comm)
|
||||
{
|
||||
struct connInfo {
|
||||
cudaIpcMemHandle_t handleBuff;
|
||||
cudaIpcMemHandle_t handleFlag;
|
||||
mscclppIbQpInfo qpInfo;
|
||||
mscclppIbMrInfo mrInfo;
|
||||
};
|
||||
// Allocate connection info to be shared with GPU
|
||||
MSCCLPPCHECK(mscclppCudaHostCalloc(&comm->devConns, comm->nConns));
|
||||
|
||||
// Send info to peers
|
||||
for (int i = 0; i < comm->nConns; ++i) {
|
||||
struct mscclppConn *conn = &comm->conns[i];
|
||||
struct mscclppDevConn *devConn = &comm->devConns[i];
|
||||
conn->devConn = devConn;
|
||||
devConn->tag = conn->tag;
|
||||
devConn->localBuff = conn->buff;
|
||||
devConn->localFlag = conn->flag;
|
||||
MSCCLPPCHECK(mscclppCudaHostCalloc(&devConn->trigger, 1));
|
||||
|
||||
struct connInfo cInfo;
|
||||
if (conn->transport == mscclppTransportP2P) {
|
||||
CUDACHECK(cudaIpcGetMemHandle(&cInfo.handleBuff, conn->buff));
|
||||
CUDACHECK(cudaIpcGetMemHandle(&cInfo.handleFlag, conn->flag));
|
||||
CUDACHECK(cudaIpcGetMemHandle(&cInfo.handleBuff, devConn->localBuff));
|
||||
CUDACHECK(cudaIpcGetMemHandle(&cInfo.handleFlag, devConn->localFlag));
|
||||
} else if (conn->transport == mscclppTransportIB) {
|
||||
devConn->remoteBuff = NULL;
|
||||
MSCCLPPCHECK(mscclppCudaCalloc(&devConn->remoteFlag, 1));
|
||||
|
||||
struct mscclppIbContext *ibCtx = conn->ibCtx;
|
||||
if (conn->ibQp == NULL) {
|
||||
MSCCLPPCHECK(mscclppIbContextCreateQp(ibCtx, &conn->ibQp));
|
||||
}
|
||||
MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, conn->buff, conn->buffSize, &conn->ibMr));
|
||||
cInfo.qpInfo = conn->ibQp->info;
|
||||
cInfo.mrInfo = conn->ibMr->info;
|
||||
MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->localBuff, conn->buffSize, &conn->ibBuffMr));
|
||||
MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->localFlag, sizeof(int), &conn->ibLocalFlagMr));
|
||||
MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->remoteFlag, sizeof(int), &conn->ibRemoteFlagMr));
|
||||
cInfo.infoQp = conn->ibQp->info;
|
||||
cInfo.infoBuffMr = conn->ibBuffMr->info;
|
||||
cInfo.infoLocalFlagMr = conn->ibLocalFlagMr->info;
|
||||
cInfo.infoRemoteFlagMr = conn->ibRemoteFlagMr->info;
|
||||
}
|
||||
int peer = conn->rankSend == comm->rank ? conn->rankRecv : conn->rankSend;
|
||||
MSCCLPPCHECK(bootstrapSend(comm->bootstrap, peer, conn->tag, &cInfo, sizeof(cInfo)));
|
||||
MSCCLPPCHECK(bootstrapSend(comm->bootstrap, conn->remoteRank, conn->tag, &cInfo, sizeof(cInfo)));
|
||||
}
|
||||
|
||||
// Allocate connection info to be shared with GPU
|
||||
MSCCLPPCHECK(mscclppCudaHostCalloc(&comm->devConns, comm->nConns));
|
||||
|
||||
// Recv info from peers
|
||||
for (int i = 0; i < comm->nConns; ++i) {
|
||||
struct mscclppConn *conn = &comm->conns[i];
|
||||
struct mscclppDevConn *devConn = &comm->devConns[i];
|
||||
|
||||
devConn->tag = conn->tag;
|
||||
devConn->localBuff = conn->buff;
|
||||
devConn->localFlag = conn->flag;
|
||||
|
||||
struct connInfo cInfo;
|
||||
int peer = conn->rankSend == comm->rank ? conn->rankRecv : conn->rankSend;
|
||||
MSCCLPPCHECK(bootstrapRecv(comm->bootstrap, peer, conn->tag, &cInfo, sizeof(cInfo)));
|
||||
MSCCLPPCHECK(bootstrapRecv(comm->bootstrap, conn->remoteRank, conn->tag, &cInfo, sizeof(cInfo)));
|
||||
if (conn->transport == mscclppTransportP2P) {
|
||||
CUDACHECK(cudaIpcOpenMemHandle(&devConn->remoteBuff, cInfo.handleBuff, cudaIpcMemLazyEnablePeerAccess));
|
||||
CUDACHECK(cudaIpcOpenMemHandle((void **)&devConn->remoteFlag, cInfo.handleFlag, cudaIpcMemLazyEnablePeerAccess));
|
||||
} else if (conn->transport == mscclppTransportIB) {
|
||||
if (conn->ibQp->rtr(&cInfo.qpInfo) != 0) {
|
||||
if (conn->ibQp->rtr(&cInfo.infoQp) != 0) {
|
||||
WARN("Failed to transition QP to RTR");
|
||||
return mscclppInvalidUsage;
|
||||
}
|
||||
@@ -247,9 +254,9 @@ mscclppResult_t mscclppConnectionSetup(mscclppComm_t comm)
|
||||
WARN("Failed to transition QP to RTS");
|
||||
return mscclppInvalidUsage;
|
||||
}
|
||||
conn->ibRemoteMrInfo = cInfo.mrInfo;
|
||||
devConn->remoteBuff = NULL;
|
||||
CUDACHECK(cudaMalloc(&devConn->remoteFlag, sizeof(int)));
|
||||
conn->ibBuffMrInfo = cInfo.infoBuffMr;
|
||||
conn->ibLocalFlagMrInfo = cInfo.infoLocalFlagMr;
|
||||
conn->ibRemoteFlagMrInfo = cInfo.infoRemoteFlagMr;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@
|
||||
#include <sys/syscall.h>
|
||||
#include <map>
|
||||
|
||||
#define MSCCLPP_PROXY_FLAG_SET_BY_RDMA 1
|
||||
|
||||
struct proxyArgs {
|
||||
struct mscclppComm* comm;
|
||||
struct mscclppIbContext* ibCtx;
|
||||
@@ -27,45 +29,47 @@ void* mscclppProxyService(void* _args) {
|
||||
};
|
||||
|
||||
int rank = comm->rank;
|
||||
std::map<int, struct mscclppConn *> recvTagToConn;
|
||||
std::map<int, struct mscclppConn *> sendTagToConn;
|
||||
std::map<struct mscclppConn *, int> sendConnToState;
|
||||
std::map<uint32_t, struct mscclppConn *> qpNumToConn;
|
||||
std::map<volatile uint64_t *, std::pair<int, struct mscclppConn *>> trigToSendStateAndConn;
|
||||
for (int i = 0; i < comm->nConns; ++i) {
|
||||
struct mscclppConn *conn = &comm->conns[i];
|
||||
if (conn->transport != mscclppTransportIB) continue;
|
||||
if (conn->ibCtx != ibCtx) continue;
|
||||
if (conn->rankRecv == rank) {
|
||||
recvTagToConn[conn->tag] = conn;
|
||||
} else if (conn->rankSend == rank) {
|
||||
sendTagToConn[conn->tag] = conn;
|
||||
sendConnToState[conn] = SEND_STATE_INIT;
|
||||
}
|
||||
}
|
||||
// Initial post recv
|
||||
for (auto &pair : recvTagToConn) {
|
||||
struct mscclppConn *conn = pair.second;
|
||||
int tag = pair.first;
|
||||
if (conn->ibQp->postRecv((uint64_t)-tag) != 0) {
|
||||
volatile uint64_t *tmp = (volatile uint64_t *)conn->devConn->trigger;
|
||||
trigToSendStateAndConn[tmp].first = SEND_STATE_INIT;
|
||||
trigToSendStateAndConn[tmp].second = conn;
|
||||
qpNumToConn[conn->ibQp->qp->qp_num] = conn;
|
||||
// All connections may read
|
||||
if (conn->ibQp->postRecv(0) != 0) {
|
||||
WARN("postRecv failed: errno %d", errno);
|
||||
}
|
||||
}
|
||||
// TODO(chhwang): run send and recv in different threads for lower latency
|
||||
mscclppTrigger trigger;
|
||||
int wcNum;
|
||||
while (*stop == 0) {
|
||||
// Try send
|
||||
for (auto &pair : sendConnToState) {
|
||||
if (pair.second == SEND_STATE_INPROGRESS) continue;
|
||||
// TODO(chhwang): do we need a thread per flag?
|
||||
struct mscclppConn *conn = pair.first;
|
||||
volatile int *flag = (volatile int *)conn->flag;
|
||||
if (*flag == 0) continue;
|
||||
// TODO(chhwang): one thread per conn
|
||||
for (auto &pair : trigToSendStateAndConn) {
|
||||
if (pair.second.first != SEND_STATE_INIT) continue;
|
||||
trigger.value = *pair.first;
|
||||
if (trigger.value == 0) continue;
|
||||
// Do send
|
||||
conn->ibQp->stageSend(conn->ibMr, &conn->ibRemoteMrInfo, conn->buffSize,
|
||||
(uint64_t)conn->tag, (unsigned int)conn->tag);
|
||||
struct mscclppConn *conn = pair.second.second;
|
||||
#if (MSCCLPP_PROXY_FLAG_SET_BY_RDMA == 1)
|
||||
conn->ibQp->stageSend(conn->ibBuffMr, &conn->ibBuffMrInfo, (uint32_t)trigger.fields.dataSize,
|
||||
/*wrId=*/0, /*immData=*/0, /*offset=*/trigger.fields.dataOffset, /*signaled=*/false);
|
||||
// My local flag is copied to the peer's remote flag
|
||||
conn->ibQp->stageSend(conn->ibLocalFlagMr, &conn->ibRemoteFlagMrInfo, sizeof(int),
|
||||
/*wrId=*/0, /*immData=*/0, /*offset=*/0, /*signaled=*/true);
|
||||
#else
|
||||
conn->ibQp->stageSend(conn->ibBuffMr, &conn->ibBuffMrInfo, (uint32_t)trigger.fields.dataSize,
|
||||
/*wrId=*/0, /*immData=*/0, /*offset=*/trigger.fields.dataOffset, /*signaled=*/true);
|
||||
#endif
|
||||
if (conn->ibQp->postSend() != 0) {
|
||||
WARN("postSend failed: errno %d", errno);
|
||||
}
|
||||
pair.second = SEND_STATE_INPROGRESS;
|
||||
pair.second.first = SEND_STATE_INPROGRESS;
|
||||
}
|
||||
|
||||
// Poll completions
|
||||
@@ -74,32 +78,26 @@ void* mscclppProxyService(void* _args) {
|
||||
for (int i = 0; i < wcNum; ++i) {
|
||||
struct ibv_wc *wc = &ibCtx->wcs[i];
|
||||
if (wc->status != IBV_WC_SUCCESS) {
|
||||
WARN("wc status %d", wc->status);
|
||||
WARN("rank %d wc status %d", rank, wc->status);
|
||||
continue;
|
||||
}
|
||||
if (((int)wc->wr_id) < 0) {
|
||||
// recv
|
||||
auto search = recvTagToConn.find(wc->imm_data);
|
||||
if (search == recvTagToConn.end()) {
|
||||
WARN("unexpected imm_data %d", wc->imm_data);
|
||||
}
|
||||
struct mscclppConn *conn = search->second;
|
||||
if (conn->ibQp->postRecv((uint64_t)-wc->imm_data) != 0) {
|
||||
struct mscclppConn *conn = qpNumToConn[wc->qp_num];
|
||||
if (wc->opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
|
||||
// recv completion
|
||||
if (qpNumToConn[wc->qp_num]->ibQp->postRecv(wc->wr_id) != 0) {
|
||||
WARN("postRecv failed: errno %d", errno);
|
||||
}
|
||||
volatile int *flag = (volatile int *)conn->flag;
|
||||
*flag = 1;
|
||||
} else {
|
||||
// send
|
||||
int tag = (int)wc->wr_id;
|
||||
auto search = sendTagToConn.find(tag);
|
||||
if (search == sendTagToConn.end()) {
|
||||
WARN("unexpected tag %d", tag);
|
||||
}
|
||||
struct mscclppConn *conn = search->second;
|
||||
volatile int *flag = (volatile int *)conn->flag;
|
||||
*flag = 0;
|
||||
sendConnToState[conn] = SEND_STATE_INIT;
|
||||
// WARN("send done rank %d", rank);
|
||||
#if (MSCCLPP_PROXY_FLAG_SET_BY_RDMA != 1)
|
||||
// TODO(chhwang): gdc & cpu flush
|
||||
// *((volatile int *)conn->devConn->remoteFlag) = 1;
|
||||
#endif
|
||||
// WARN("rank %d recv completion", rank);
|
||||
} else if (wc->opcode == IBV_WC_RDMA_WRITE) {
|
||||
// send completion
|
||||
volatile uint64_t *tmp = (volatile uint64_t *)conn->devConn->trigger;
|
||||
*tmp = 0;
|
||||
trigToSendStateAndConn[tmp].first = SEND_STATE_INIT;
|
||||
// WARN("rank %d send completion", rank);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -160,8 +160,7 @@
|
||||
|
||||
struct mscclppConn {
|
||||
mscclppTransport_t transport;
|
||||
int rankSend;
|
||||
int rankRecv;
|
||||
int remoteRank;
|
||||
int tag;
|
||||
void* buff;
|
||||
int buffSize;
|
||||
@@ -169,8 +168,12 @@ struct mscclppConn {
|
||||
struct mscclppDevConn *devConn;
|
||||
struct mscclppIbContext *ibCtx;
|
||||
struct mscclppIbQp *ibQp;
|
||||
struct mscclppIbMr *ibMr;
|
||||
struct mscclppIbMrInfo ibRemoteMrInfo;
|
||||
struct mscclppIbMr *ibBuffMr;
|
||||
struct mscclppIbMr *ibLocalFlagMr;
|
||||
struct mscclppIbMr *ibRemoteFlagMr;
|
||||
struct mscclppIbMrInfo ibBuffMrInfo;
|
||||
struct mscclppIbMrInfo ibLocalFlagMrInfo;
|
||||
struct mscclppIbMrInfo ibRemoteFlagMrInfo;
|
||||
};
|
||||
|
||||
struct mscclppComm {
|
||||
|
||||
@@ -45,8 +45,8 @@ struct mscclppIbQp {
|
||||
|
||||
int rtr(const mscclppIbQpInfo *info);
|
||||
int rts();
|
||||
int stageSend(struct mscclppIbMr *ibMr, const mscclppIbMrInfo *info, int size,
|
||||
uint64_t wrId, unsigned int immData, int offset = 0);
|
||||
int stageSend(struct mscclppIbMr *ibMr, const mscclppIbMrInfo *info, uint32_t size,
|
||||
uint64_t wrId, unsigned int immData, uint64_t offset, bool signaled);
|
||||
int postSend();
|
||||
int postRecv(uint64_t wrId);
|
||||
};
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#if CUDART_VERSION >= 11000
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
#include <stdint.h>
|
||||
|
||||
#define MSCCLPP_MAJOR 0
|
||||
#define MSCCLPP_MINOR 1
|
||||
@@ -16,6 +17,14 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
union alignas(8) mscclppTrigger {
|
||||
uint64_t value;
|
||||
struct {
|
||||
uint64_t dataSize : 32;
|
||||
uint64_t dataOffset : 32;
|
||||
} fields;
|
||||
};
|
||||
|
||||
struct mscclppDevConn {
|
||||
int tag;
|
||||
|
||||
@@ -32,6 +41,8 @@ struct mscclppDevConn {
|
||||
// virtual void pullRmoteFlag();
|
||||
// // localBuff[srcOffset..srcOffset+size-1] <- remoteBuff[dstOffset..dstOffset+size-1]
|
||||
// virtual void pullRemoteBuff(size_t srcOffset, size_t dstOffset, size_t size);
|
||||
|
||||
mscclppTrigger* trigger;
|
||||
};
|
||||
|
||||
typedef struct mscclppComm* mscclppComm_t;
|
||||
@@ -102,8 +113,8 @@ mscclppResult_t mscclppBootStrapAllGather(mscclppComm_t comm, void* data, int si
|
||||
|
||||
mscclppResult_t mscclppCommDestroy(mscclppComm_t comm);
|
||||
|
||||
mscclppResult_t mscclppConnect(mscclppComm_t comm, int rankRecv, int rankSend, void *buff, size_t buffSize, int *flag, int tag,
|
||||
mscclppTransport_t transportType, const char *ibDev=NULL);
|
||||
mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, void *buff, size_t buffSize, int *flag,
|
||||
int tag, mscclppTransport_t transportType, const char *ibDev=NULL);
|
||||
|
||||
mscclppResult_t mscclppConnectionSetup(mscclppComm_t comm);
|
||||
|
||||
|
||||
@@ -14,75 +14,51 @@
|
||||
#define CUDACHECK(cmd) do { \
|
||||
cudaError_t err = cmd; \
|
||||
if( err != cudaSuccess ) { \
|
||||
printf("Cuda failure '%s'", cudaGetErrorString(err)); \
|
||||
printf("%s:%d Cuda failure '%s'", __FILE__, __LINE__, cudaGetErrorString(err)); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
} while(false)
|
||||
|
||||
__global__ void kernel(mscclppDevConn_t devConns, int rank, int world_size)
|
||||
{
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (tid == 0) {
|
||||
// Get sending data and send flag
|
||||
volatile int *data;
|
||||
for (int i = 0; i < (world_size - 1) * 2; ++i) {
|
||||
mscclppDevConn_t devConn = &devConns[i];
|
||||
int tag = devConn->tag;
|
||||
int rankSend = tag % world_size;
|
||||
if (rankSend == rank) { // I am a sender
|
||||
data = (volatile int *)devConn->localBuff;
|
||||
// We are sending the same data to all peers, so just break here
|
||||
break;
|
||||
}
|
||||
}
|
||||
int warpId = threadIdx.x / 32;
|
||||
int remoteRank = (warpId < rank) ? warpId : warpId + 1;
|
||||
mscclppDevConn_t devConn = &devConns[(remoteRank < rank) ? remoteRank : remoteRank - 1];
|
||||
volatile int *data = (volatile int *)devConn->localBuff;
|
||||
volatile int *localFlag = devConn->localFlag;
|
||||
volatile int *remoteFlag = devConn->remoteFlag;
|
||||
volatile uint64_t *trig = (volatile uint64_t *)devConn->trigger;
|
||||
|
||||
// Set my data
|
||||
*data = rank + 1;
|
||||
if (threadIdx.x == 0) {
|
||||
// Set my data and flag
|
||||
*(data + rank) = rank + 1;
|
||||
__threadfence_system();
|
||||
*localFlag = 1;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Set send flags to inform all peers that the data is ready
|
||||
for (int i = 0; i < (world_size - 1) * 2; ++i) {
|
||||
mscclppDevConn_t devConn = &devConns[i];
|
||||
int tag = devConn->tag;
|
||||
int rankSend = tag % world_size;
|
||||
if (rankSend == rank) { // I am a sender
|
||||
*((volatile int *)devConn->localFlag) = 1;
|
||||
}
|
||||
}
|
||||
// Each warp receives data from different ranks
|
||||
if (threadIdx.x % 32 == 0) {
|
||||
if (devConn->remoteBuff == NULL) { // IB
|
||||
// Trigger sending data and flag
|
||||
uint64_t dataOffset = rank * sizeof(int);
|
||||
uint64_t dataSize = sizeof(int);
|
||||
*trig = (dataOffset << 32) + dataSize;
|
||||
|
||||
// Read data from all other peers
|
||||
for (int i = 0; i < (world_size - 1) * 2; ++i) {
|
||||
mscclppDevConn_t devConn = &devConns[i];
|
||||
int tag = devConn->tag;
|
||||
int rankSend = tag % world_size;
|
||||
int rankRecv = tag / world_size;
|
||||
if (rankRecv == rank) { // I am a receiver
|
||||
if (devConn->remoteBuff == NULL) { // IB
|
||||
volatile int *localFlag = (volatile int *)devConn->localFlag;
|
||||
// Wait until the proxy have sent my data and flag
|
||||
while (*trig != 0) {}
|
||||
|
||||
// Wait until the data comes in via proxy
|
||||
while (*localFlag != 1) {}
|
||||
} else { // P2P
|
||||
volatile int *remoteData = (volatile int *)devConn->remoteBuff;
|
||||
volatile int *remoteFlag = (volatile int *)devConn->remoteFlag;
|
||||
// Wait for receiving data from remote rank
|
||||
while (*remoteFlag != 1) {}
|
||||
} else { // P2P
|
||||
// Directly read data
|
||||
volatile int *remoteData = (volatile int *)devConn->remoteBuff;
|
||||
|
||||
// Wait until the remote data is set
|
||||
while (*remoteFlag != 1) {}
|
||||
// Wait until the remote data is set
|
||||
while (*remoteFlag != 1) {}
|
||||
|
||||
// Read remote data
|
||||
data[rankSend] = remoteData[rankSend];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wait until the proxy have sent my data to all peers
|
||||
for (int i = 0; i < (world_size - 1) * 2; ++i) {
|
||||
mscclppDevConn_t devConn = &devConns[i];
|
||||
int tag = devConn->tag;
|
||||
int rankSend = tag % world_size;
|
||||
if (rankSend == rank) { // I am a sender
|
||||
volatile int *flag = (volatile int *)devConn->localFlag;
|
||||
while (*flag == 1) {}
|
||||
}
|
||||
// Read remote data
|
||||
data[remoteRank] = remoteData[remoteRank];
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -133,6 +109,8 @@ int main(int argc, const char *argv[])
|
||||
int rank = atoi(argv[2]);
|
||||
int world_size = atoi(argv[3]);
|
||||
#endif
|
||||
int localRank = rankToLocalRank(rank);
|
||||
int thisNode = rankToNode(rank);
|
||||
|
||||
mscclppComm_t comm;
|
||||
mscclppResult_t res = mscclppCommInitRank(&comm, world_size, rank, ip_port);
|
||||
@@ -141,64 +119,33 @@ int main(int argc, const char *argv[])
|
||||
return -1;
|
||||
}
|
||||
|
||||
CUDACHECK(cudaSetDevice(localRank));
|
||||
|
||||
int *data_d;
|
||||
int *send_flags_d;
|
||||
int *recv_flags_d;
|
||||
int *flag_d;
|
||||
CUDACHECK(cudaMalloc(&data_d, sizeof(int) * world_size));
|
||||
CUDACHECK(cudaHostAlloc(&send_flags_d, sizeof(int) * (world_size - 1), cudaHostAllocMapped));
|
||||
CUDACHECK(cudaHostAlloc(&recv_flags_d, sizeof(int) * (world_size - 1), cudaHostAllocMapped));
|
||||
|
||||
CUDACHECK(cudaMalloc(&flag_d, sizeof(int)));
|
||||
CUDACHECK(cudaMemset(data_d, 0, sizeof(int) * world_size));
|
||||
// CUDACHECK(cudaMemcpy(data_d, tmp, sizeof(int) * 2, cudaMemcpyHostToDevice));
|
||||
// printf("rank %d CPU: setting data at %p\n", rank, data_d + rank);
|
||||
memset(send_flags_d, 0, sizeof(int) * (world_size - 1));
|
||||
memset(recv_flags_d, 0, sizeof(int) * (world_size - 1));
|
||||
CUDACHECK(cudaMemset(flag_d, 0, sizeof(int)));
|
||||
|
||||
int localRank = rankToLocalRank(rank);
|
||||
int thisNode = rankToNode(rank);
|
||||
std::string ibDev = "mlx5_ib" + std::to_string(localRank);
|
||||
std::string ibDevStr = "mlx5_ib" + std::to_string(localRank);
|
||||
|
||||
// Read from all other ranks
|
||||
int idx = 0;
|
||||
for (int r = 0; r < world_size; ++r) {
|
||||
if (r == rank) continue;
|
||||
int tag = rank * world_size + r;
|
||||
mscclppTransport_t transportType = mscclppTransportIB;
|
||||
const char *ibDev = ibDevStr.c_str();
|
||||
#if (TEST_CONN_TYPE == 0) // P2P+IB
|
||||
int node = rankToNode(r);
|
||||
if (node == thisNode) {
|
||||
res = mscclppConnect(comm, rank, r, data_d + r, sizeof(int), recv_flags_d + idx, tag, mscclppTransportP2P);
|
||||
} else {
|
||||
res = mscclppConnect(comm, rank, r, data_d + r, sizeof(int), recv_flags_d + idx, tag, mscclppTransportIB, ibDev.c_str());
|
||||
if (rankToNode(r) == thisNode) {
|
||||
transportType = mscclppTransportP2P;
|
||||
ibDev = NULL;
|
||||
}
|
||||
#else // (TEST_CONN_TYPE == 1) // IB-Only
|
||||
res = mscclppConnect(comm, rank, r, data_d + r, sizeof(int), recv_flags_d + idx, tag, mscclppTransportIB, ibDev.c_str());
|
||||
#endif
|
||||
// Connect with all other ranks
|
||||
res = mscclppConnect(comm, r, data_d, sizeof(int) * world_size, flag_d, 0, transportType, ibDev);
|
||||
if (res != mscclppSuccess) {
|
||||
printf("mscclppConnect failed\n");
|
||||
return -1;
|
||||
}
|
||||
++idx;
|
||||
}
|
||||
// Let others read from me
|
||||
idx = 0;
|
||||
for (int r = 0; r < world_size; ++r) {
|
||||
if (r == rank) continue;
|
||||
int tag = r * world_size + rank;
|
||||
#if (TEST_CONN_TYPE == 0) // P2P+IB
|
||||
int node = rankToNode(r);
|
||||
if (node == thisNode) {
|
||||
res = mscclppConnect(comm, r, rank, data_d + rank, sizeof(int), send_flags_d + idx, tag, mscclppTransportP2P);
|
||||
} else {
|
||||
res = mscclppConnect(comm, r, rank, data_d + rank, sizeof(int), send_flags_d + idx, tag, mscclppTransportIB, ibDev.c_str());
|
||||
}
|
||||
#else // (TEST_CONN_TYPE == 1) // IB-Only
|
||||
res = mscclppConnect(comm, r, rank, data_d + rank, sizeof(int), send_flags_d + idx, tag, mscclppTransportIB, ibDev.c_str());
|
||||
#endif
|
||||
if (res != mscclppSuccess) {
|
||||
printf("mscclppConnect failed\n");
|
||||
return -1;
|
||||
}
|
||||
++idx;
|
||||
}
|
||||
|
||||
res = mscclppConnectionSetup(comm);
|
||||
@@ -216,7 +163,7 @@ int main(int argc, const char *argv[])
|
||||
mscclppDevConn_t devConns;
|
||||
mscclppGetDevConns(comm, &devConns);
|
||||
|
||||
kernel<<<1, 1>>>(devConns, rank, world_size);
|
||||
kernel<<<1, 32 * (world_size - 1)>>>(devConns, rank, world_size);
|
||||
CUDACHECK(cudaDeviceSynchronize());
|
||||
|
||||
res = mscclppProxyStop(comm);
|
||||
|
||||
Reference in New Issue
Block a user