From 63a5be695355bc816bc618d343fc7b711ff628a6 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Wed, 12 Apr 2023 09:20:05 +0000 Subject: [PATCH] Move ibQp to mscclppHostIBConn --- src/include/comm.h | 1 - src/init.cc | 32 ++++++++++++++++---------------- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/src/include/comm.h b/src/include/comm.h index 366659d5..04e21b56 100644 --- a/src/include/comm.h +++ b/src/include/comm.h @@ -28,7 +28,6 @@ struct mscclppConn struct mscclppHostConn* hostConn; struct mscclppIbContext* ibCtx; - struct mscclppIbQp* ibQp; #if defined(ENABLE_NPKIT) std::vector npkitUsedReqIds; std::vector npkitFreeReqIds; diff --git a/src/init.cc b/src/init.cc index f4f47487..f04f14fa 100644 --- a/src/init.cc +++ b/src/init.cc @@ -344,11 +344,11 @@ struct mscclppHostIBConn : mscclppHostConn{ mscclppHostIBConn(mscclppConn* conn) : conn(conn) {} void put(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize){ - conn->ibQp->stageSend(this->ibBuffMr, &this->ibBuffMrInfo, (uint32_t)dataSize, + this->ibQp->stageSend(this->ibBuffMr, &this->ibBuffMrInfo, (uint32_t)dataSize, /*wrId=*/0, /*srcOffset=*/srcDataOffset, /*dstOffset=*/dstDataOffset, /*signaled=*/false); - int ret = conn->ibQp->postSend(); + int ret = this->ibQp->postSend(); if (ret != 0) { // Return value is errno. WARN("data postSend failed: errno %d", ret); @@ -357,9 +357,9 @@ struct mscclppHostIBConn : mscclppHostConn{ } void signal(){ // My local device flag is copied to the remote's proxy flag - conn->ibQp->stageSend(this->ibSignalEpochIdMr, &this->ibSignalEpochIdMrInfo, sizeof(uint64_t), + this->ibQp->stageSend(this->ibSignalEpochIdMr, &this->ibSignalEpochIdMrInfo, sizeof(uint64_t), /*wrId=*/0, /*srcOffset=*/0, /*dstOffset=*/sizeof(uint64_t), /*signaled=*/true); - int ret = conn->ibQp->postSend(); + int ret = this->ibQp->postSend(); if (ret != 0) { WARN("flag postSend failed: errno %d", ret); } @@ -369,18 +369,18 @@ struct mscclppHostIBConn : mscclppHostConn{ void flush(){ bool isWaiting = true; while (isWaiting) { - int wcNum = conn->ibQp->pollCq(); + int wcNum = this->ibQp->pollCq(); if (wcNum < 0) { WARN("pollCq failed: errno %d", errno); continue; } for (int i = 0; i < wcNum; ++i) { - struct ibv_wc* wc = &conn->ibQp->wcs[i]; + struct ibv_wc* wc = &this->ibQp->wcs[i]; if (wc->status != IBV_WC_SUCCESS) { WARN("wc status %d", wc->status); continue; } - if (wc->qp_num != conn->ibQp->qp->qp_num) { + if (wc->qp_num != this->ibQp->qp->qp_num) { WARN("got wc of unknown qp_num %d", wc->qp_num); continue; } @@ -394,6 +394,7 @@ struct mscclppHostIBConn : mscclppHostConn{ } mscclppConn* conn; + struct mscclppIbQp* ibQp; struct mscclppIbMr* ibBuffMr; struct mscclppIbMr* ibSignalEpochIdMr; struct mscclppIbMrInfo ibBuffMrInfo; @@ -429,7 +430,6 @@ MSCCLPP_API mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, i conn->buffSize = buffSize; conn->ibCtx = NULL; - conn->ibQp = NULL; int ibDevIdx = -1; if (transportType == mscclppTransportIB) { // Check if an IB context exists @@ -593,13 +593,13 @@ mscclppResult_t mscclppIbConnectionSetupStart(struct connInfo* connInfo /*output devConn->remoteSignalEpochId = NULL; struct mscclppIbContext* ibCtx = conn->ibCtx; - if (conn->ibQp == NULL) { - MSCCLPPCHECK(mscclppIbContextCreateQp(ibCtx, &conn->ibQp)); + if (hostConn->ibQp == NULL) { + MSCCLPPCHECK(mscclppIbContextCreateQp(ibCtx, &hostConn->ibQp)); } MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->localBuff, conn->buffSize, &hostConn->ibBuffMr)); MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->localSignalEpochId, sizeof(struct mscclppDevConnSignalEpochId), &hostConn->ibSignalEpochIdMr)); - connInfo->infoQp = conn->ibQp->info; + connInfo->infoQp = hostConn->ibQp->info; connInfo->infoBuffMr = hostConn->ibBuffMr->info; connInfo->infoSignalEpochIdMr = hostConn->ibSignalEpochIdMr->info; return mscclppSuccess; @@ -611,15 +611,15 @@ mscclppResult_t mscclppIbConnectionSetupEnd(struct connInfo* connInfo /*input*/, WARN("ipcHandles or connection cannot be null"); return mscclppInternalError; } - if (conn->ibQp->rtr(&connInfo->infoQp) != 0) { + struct mscclppHostIBConn* hostConn = (struct mscclppHostIBConn*)conn->hostConn; + if (hostConn->ibQp->rtr(&connInfo->infoQp) != 0) { WARN("Failed to transition QP to RTR"); return mscclppInvalidUsage; } - if (conn->ibQp->rts() != 0) { + if (hostConn->ibQp->rts() != 0) { WARN("Failed to transition QP to RTS"); return mscclppInvalidUsage; } - struct mscclppHostIBConn* hostConn = (struct mscclppHostIBConn*)conn->hostConn; hostConn->ibBuffMrInfo = connInfo->infoBuffMr; hostConn->ibSignalEpochIdMrInfo = connInfo->infoSignalEpochIdMr; return mscclppSuccess; @@ -719,11 +719,11 @@ MSCCLPP_API mscclppResult_t mscclppRegisteredBufferWrite(mscclppComm_t comm, msc CUDACHECK(cudaMemcpyAsync(dstBuff, srcBuff, size, cudaMemcpyDeviceToDevice, (cudaStream_t)stream)); } else { struct mscclppHostIBConn* hostConn = (struct mscclppHostIBConn*)conn->hostConn; - conn->ibQp->stageSend(hostConn->ibBuffMr, &hostConn->ibBuffMrInfo, (uint32_t)size, + hostConn->ibQp->stageSend(hostConn->ibBuffMr, &hostConn->ibBuffMrInfo, (uint32_t)size, /*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset, /*signaled=*/false); - if ((ret = conn->ibQp->postSend()) != 0) { + if ((ret = hostConn->ibQp->postSend()) != 0) { // Return value is errno. WARN("data postSend failed: errno %d", ret); }