mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 09:17:06 +00:00
Move ibQp to mscclppHostIBConn
This commit is contained in:
@@ -28,7 +28,6 @@ struct mscclppConn
|
||||
struct mscclppHostConn* hostConn;
|
||||
|
||||
struct mscclppIbContext* ibCtx;
|
||||
struct mscclppIbQp* ibQp;
|
||||
#if defined(ENABLE_NPKIT)
|
||||
std::vector<uint64_t> npkitUsedReqIds;
|
||||
std::vector<uint64_t> npkitFreeReqIds;
|
||||
|
||||
32
src/init.cc
32
src/init.cc
@@ -344,11 +344,11 @@ struct mscclppHostIBConn : mscclppHostConn{
|
||||
mscclppHostIBConn(mscclppConn* conn) : conn(conn) {}
|
||||
|
||||
void put(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize){
|
||||
conn->ibQp->stageSend(this->ibBuffMr, &this->ibBuffMrInfo, (uint32_t)dataSize,
|
||||
this->ibQp->stageSend(this->ibBuffMr, &this->ibBuffMrInfo, (uint32_t)dataSize,
|
||||
/*wrId=*/0, /*srcOffset=*/srcDataOffset,
|
||||
/*dstOffset=*/dstDataOffset,
|
||||
/*signaled=*/false);
|
||||
int ret = conn->ibQp->postSend();
|
||||
int ret = this->ibQp->postSend();
|
||||
if (ret != 0) {
|
||||
// Return value is errno.
|
||||
WARN("data postSend failed: errno %d", ret);
|
||||
@@ -357,9 +357,9 @@ struct mscclppHostIBConn : mscclppHostConn{
|
||||
}
|
||||
void signal(){
|
||||
// My local device flag is copied to the remote's proxy flag
|
||||
conn->ibQp->stageSend(this->ibSignalEpochIdMr, &this->ibSignalEpochIdMrInfo, sizeof(uint64_t),
|
||||
this->ibQp->stageSend(this->ibSignalEpochIdMr, &this->ibSignalEpochIdMrInfo, sizeof(uint64_t),
|
||||
/*wrId=*/0, /*srcOffset=*/0, /*dstOffset=*/sizeof(uint64_t), /*signaled=*/true);
|
||||
int ret = conn->ibQp->postSend();
|
||||
int ret = this->ibQp->postSend();
|
||||
if (ret != 0) {
|
||||
WARN("flag postSend failed: errno %d", ret);
|
||||
}
|
||||
@@ -369,18 +369,18 @@ struct mscclppHostIBConn : mscclppHostConn{
|
||||
void flush(){
|
||||
bool isWaiting = true;
|
||||
while (isWaiting) {
|
||||
int wcNum = conn->ibQp->pollCq();
|
||||
int wcNum = this->ibQp->pollCq();
|
||||
if (wcNum < 0) {
|
||||
WARN("pollCq failed: errno %d", errno);
|
||||
continue;
|
||||
}
|
||||
for (int i = 0; i < wcNum; ++i) {
|
||||
struct ibv_wc* wc = &conn->ibQp->wcs[i];
|
||||
struct ibv_wc* wc = &this->ibQp->wcs[i];
|
||||
if (wc->status != IBV_WC_SUCCESS) {
|
||||
WARN("wc status %d", wc->status);
|
||||
continue;
|
||||
}
|
||||
if (wc->qp_num != conn->ibQp->qp->qp_num) {
|
||||
if (wc->qp_num != this->ibQp->qp->qp_num) {
|
||||
WARN("got wc of unknown qp_num %d", wc->qp_num);
|
||||
continue;
|
||||
}
|
||||
@@ -394,6 +394,7 @@ struct mscclppHostIBConn : mscclppHostConn{
|
||||
}
|
||||
|
||||
mscclppConn* conn;
|
||||
struct mscclppIbQp* ibQp;
|
||||
struct mscclppIbMr* ibBuffMr;
|
||||
struct mscclppIbMr* ibSignalEpochIdMr;
|
||||
struct mscclppIbMrInfo ibBuffMrInfo;
|
||||
@@ -429,7 +430,6 @@ MSCCLPP_API mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, i
|
||||
conn->buffSize = buffSize;
|
||||
|
||||
conn->ibCtx = NULL;
|
||||
conn->ibQp = NULL;
|
||||
int ibDevIdx = -1;
|
||||
if (transportType == mscclppTransportIB) {
|
||||
// Check if an IB context exists
|
||||
@@ -593,13 +593,13 @@ mscclppResult_t mscclppIbConnectionSetupStart(struct connInfo* connInfo /*output
|
||||
devConn->remoteSignalEpochId = NULL;
|
||||
|
||||
struct mscclppIbContext* ibCtx = conn->ibCtx;
|
||||
if (conn->ibQp == NULL) {
|
||||
MSCCLPPCHECK(mscclppIbContextCreateQp(ibCtx, &conn->ibQp));
|
||||
if (hostConn->ibQp == NULL) {
|
||||
MSCCLPPCHECK(mscclppIbContextCreateQp(ibCtx, &hostConn->ibQp));
|
||||
}
|
||||
MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->localBuff, conn->buffSize, &hostConn->ibBuffMr));
|
||||
MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->localSignalEpochId,
|
||||
sizeof(struct mscclppDevConnSignalEpochId), &hostConn->ibSignalEpochIdMr));
|
||||
connInfo->infoQp = conn->ibQp->info;
|
||||
connInfo->infoQp = hostConn->ibQp->info;
|
||||
connInfo->infoBuffMr = hostConn->ibBuffMr->info;
|
||||
connInfo->infoSignalEpochIdMr = hostConn->ibSignalEpochIdMr->info;
|
||||
return mscclppSuccess;
|
||||
@@ -611,15 +611,15 @@ mscclppResult_t mscclppIbConnectionSetupEnd(struct connInfo* connInfo /*input*/,
|
||||
WARN("ipcHandles or connection cannot be null");
|
||||
return mscclppInternalError;
|
||||
}
|
||||
if (conn->ibQp->rtr(&connInfo->infoQp) != 0) {
|
||||
struct mscclppHostIBConn* hostConn = (struct mscclppHostIBConn*)conn->hostConn;
|
||||
if (hostConn->ibQp->rtr(&connInfo->infoQp) != 0) {
|
||||
WARN("Failed to transition QP to RTR");
|
||||
return mscclppInvalidUsage;
|
||||
}
|
||||
if (conn->ibQp->rts() != 0) {
|
||||
if (hostConn->ibQp->rts() != 0) {
|
||||
WARN("Failed to transition QP to RTS");
|
||||
return mscclppInvalidUsage;
|
||||
}
|
||||
struct mscclppHostIBConn* hostConn = (struct mscclppHostIBConn*)conn->hostConn;
|
||||
hostConn->ibBuffMrInfo = connInfo->infoBuffMr;
|
||||
hostConn->ibSignalEpochIdMrInfo = connInfo->infoSignalEpochIdMr;
|
||||
return mscclppSuccess;
|
||||
@@ -719,11 +719,11 @@ MSCCLPP_API mscclppResult_t mscclppRegisteredBufferWrite(mscclppComm_t comm, msc
|
||||
CUDACHECK(cudaMemcpyAsync(dstBuff, srcBuff, size, cudaMemcpyDeviceToDevice, (cudaStream_t)stream));
|
||||
} else {
|
||||
struct mscclppHostIBConn* hostConn = (struct mscclppHostIBConn*)conn->hostConn;
|
||||
conn->ibQp->stageSend(hostConn->ibBuffMr, &hostConn->ibBuffMrInfo, (uint32_t)size,
|
||||
hostConn->ibQp->stageSend(hostConn->ibBuffMr, &hostConn->ibBuffMrInfo, (uint32_t)size,
|
||||
/*wrId=*/0, /*srcOffset=*/srcOffset,
|
||||
/*dstOffset=*/dstOffset,
|
||||
/*signaled=*/false);
|
||||
if ((ret = conn->ibQp->postSend()) != 0) {
|
||||
if ((ret = hostConn->ibQp->postSend()) != 0) {
|
||||
// Return value is errno.
|
||||
WARN("data postSend failed: errno %d", ret);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user