From bc729cd48161d93763e0a89f086e8e27abff46e6 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Wed, 12 Apr 2023 09:05:42 +0000 Subject: [PATCH] Move MRs / MR infos to mscclppHostIBConn & cleanup --- src/include/comm.h | 5 +-- src/include/mscclpp.h | 1 + src/init.cc | 88 ++++++++++++++++++++++++++++++++++++------- src/proxy.cc | 66 ++------------------------------ 4 files changed, 80 insertions(+), 80 deletions(-) diff --git a/src/include/comm.h b/src/include/comm.h index c4927143..366659d5 100644 --- a/src/include/comm.h +++ b/src/include/comm.h @@ -20,6 +20,7 @@ struct mscclppConn { + int connId; mscclppTransport_t transport; int remoteRank; uint64_t buffSize; @@ -28,10 +29,6 @@ struct mscclppConn struct mscclppIbContext* ibCtx; struct mscclppIbQp* ibQp; - struct mscclppIbMr* ibBuffMr; - struct mscclppIbMr* ibSignalEpochIdMr; - struct mscclppIbMrInfo ibBuffMrInfo; - struct mscclppIbMrInfo ibSignalEpochIdMrInfo; #if defined(ENABLE_NPKIT) std::vector npkitUsedReqIds; std::vector npkitFreeReqIds; diff --git a/src/include/mscclpp.h b/src/include/mscclpp.h index 11e0cbf6..b7db058b 100644 --- a/src/include/mscclpp.h +++ b/src/include/mscclpp.h @@ -188,6 +188,7 @@ struct mscclppDevConn // Host interface for mscclppDevCon functionality struct mscclppHostConn{ + virtual ~mscclppHostConn() = default; virtual void put(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize) = 0; virtual void signal() = 0; virtual void wait() = 0; diff --git a/src/init.cc b/src/init.cc index aaf655ea..f4f47487 100644 --- a/src/init.cc +++ b/src/init.cc @@ -267,6 +267,53 @@ MSCCLPP_API mscclppResult_t mscclppGetAllDeviceConnections(mscclppComm_t comm, m return mscclppSuccess; } +#if defined(ENABLE_NPKIT) + +static void npkitInitReqIds(struct mscclppComm* comm) +{ + for (int i = 0; i < comm->nConns; i++) { + struct mscclppConn* conn = &comm->conns[i]; + conn->npkitUsedReqIds.resize(0); + conn->npkitFreeReqIds.resize(MSCCLPP_IB_MAX_SENDS); + for (uint64_t j = 0; j < MSCCLPP_IB_MAX_SENDS; j++) { + conn->npkitFreeReqIds[j] = MSCCLPP_IB_MAX_SENDS - j - 1; + } + } +} + +static void npkitCollectEntryEvent(struct mscclppConn* conn, uint8_t type, uint32_t size) +{ + uint64_t reqId = 0; + if (conn->npkitFreeReqIds.size() == 0) { + reqId = conn->npkitUsedReqIds.size(); + } else { + reqId = conn->npkitFreeReqIds.back(); + conn->npkitFreeReqIds.pop_back(); + } + conn->npkitUsedReqIds.push_back(reqId); + NpKit::CollectCpuEvent(type, size, (uint32_t)reqId, NpKit::GetCpuTimestamp(), conn->connId); +} + +static void npkitCollectExitEvents(struct mscclppConn* conn, uint8_t type) +{ + while (conn->npkitUsedReqIds.size()) { + uint64_t reqId = conn->npkitUsedReqIds.back(); + NpKit::CollectCpuEvent(type, 0, (uint32_t)reqId, NpKit::GetCpuTimestamp(), conn->connId); + conn->npkitFreeReqIds.push_back(reqId); + conn->npkitUsedReqIds.pop_back(); + } +} + +#else + +#define npkitInitReqIds(comm) + +#define npkitCollectEntryEvent(conn, type, size) + +#define npkitCollectExitEvents(conn, type) + +#endif + struct mscclppHostP2PConn : mscclppHostConn{ mscclppHostP2PConn(mscclppConn* _conn, cudaStream_t _stream) : conn(_conn), p2pStream(_stream){} @@ -275,15 +322,18 @@ struct mscclppHostP2PConn : mscclppHostConn{ void* srcBuff = (void*)((char*)conn->devConn->localBuff + srcDataOffset); void* dstBuff = (void*)((char*)conn->devConn->remoteBuff + dstDataOffset); CUDACHECKNORET(cudaMemcpyAsync(dstBuff, srcBuff, dataSize, cudaMemcpyDeviceToDevice, p2pStream)); + npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)dataSize); } void signal(){ CUDACHECKNORET(cudaMemcpyAsync(&conn->devConn->remoteSignalEpochId->proxy, &(conn->devConn->localSignalEpochId->device), sizeof(uint64_t), cudaMemcpyDeviceToDevice, p2pStream)); + npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_FLAG_ENTRY, (uint32_t)sizeof(uint64_t)); } void wait(){} void flush(){ CUDACHECKNORET(cudaStreamSynchronize(p2pStream)); + npkitCollectExitEvents(conn, NPKIT_EVENT_DMA_SEND_EXIT); } mscclppConn* conn; @@ -294,7 +344,7 @@ struct mscclppHostIBConn : mscclppHostConn{ mscclppHostIBConn(mscclppConn* conn) : conn(conn) {} void put(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize){ - conn->ibQp->stageSend(conn->ibBuffMr, &conn->ibBuffMrInfo, (uint32_t)dataSize, + conn->ibQp->stageSend(this->ibBuffMr, &this->ibBuffMrInfo, (uint32_t)dataSize, /*wrId=*/0, /*srcOffset=*/srcDataOffset, /*dstOffset=*/dstDataOffset, /*signaled=*/false); @@ -303,15 +353,17 @@ struct mscclppHostIBConn : mscclppHostConn{ // Return value is errno. WARN("data postSend failed: errno %d", ret); } + npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)dataSize); } void signal(){ // My local device flag is copied to the remote's proxy flag - conn->ibQp->stageSend(conn->ibSignalEpochIdMr, &conn->ibSignalEpochIdMrInfo, sizeof(uint64_t), + conn->ibQp->stageSend(this->ibSignalEpochIdMr, &this->ibSignalEpochIdMrInfo, sizeof(uint64_t), /*wrId=*/0, /*srcOffset=*/0, /*dstOffset=*/sizeof(uint64_t), /*signaled=*/true); int ret = conn->ibQp->postSend(); if (ret != 0) { WARN("flag postSend failed: errno %d", ret); } + npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_FLAG_ENTRY, (uint32_t)sizeof(uint64_t)); } void wait(){} void flush(){ @@ -338,9 +390,14 @@ struct mscclppHostIBConn : mscclppHostConn{ } } } + npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT); } mscclppConn* conn; + struct mscclppIbMr* ibBuffMr; + struct mscclppIbMr* ibSignalEpochIdMr; + struct mscclppIbMrInfo ibBuffMrInfo; + struct mscclppIbMrInfo ibSignalEpochIdMrInfo; }; @@ -365,7 +422,9 @@ MSCCLPP_API mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, i WARN("Too many connections made"); return mscclppInternalError; } - struct mscclppConn* conn = &comm->conns[comm->nConns]; + int connId = comm->nConns; + struct mscclppConn* conn = &comm->conns[connId]; + conn->connId = connId; conn->transport = transportType; conn->buffSize = buffSize; @@ -463,8 +522,7 @@ MSCCLPP_API mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, i conn->hostConn = new mscclppHostP2PConn(conn, proxyState->p2pStream); } - - struct mscclppDevConn* devConn = &comm->devConns[comm->nConns]; + struct mscclppDevConn* devConn = &comm->devConns[connId]; conn->devConn = devConn; conn->devConn->localBuff = localBuff; @@ -472,7 +530,7 @@ MSCCLPP_API mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, i MSCCLPPCHECK(mscclppCudaCalloc(&conn->devConn->waitEpochId, 1)); conn->devConn->remoteRank = remoteRank; conn->devConn->tag = tag; - conn->devConn->fifo.connId = comm->nConns; + conn->devConn->fifo.connId = connId; #if defined(MSCCLPP_USE_GDRCOPY) conn->devConn->fifo.triggerFifo = proxyState->fifo.triggerFifoDev; #else @@ -530,6 +588,7 @@ mscclppResult_t mscclppIbConnectionSetupStart(struct connInfo* connInfo /*output return mscclppInternalError; } struct mscclppDevConn* devConn = conn->devConn; + struct mscclppHostIBConn* hostConn = (struct mscclppHostIBConn*)conn->hostConn; devConn->remoteBuff = NULL; devConn->remoteSignalEpochId = NULL; @@ -537,12 +596,12 @@ mscclppResult_t mscclppIbConnectionSetupStart(struct connInfo* connInfo /*output if (conn->ibQp == NULL) { MSCCLPPCHECK(mscclppIbContextCreateQp(ibCtx, &conn->ibQp)); } - MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->localBuff, conn->buffSize, &conn->ibBuffMr)); + MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->localBuff, conn->buffSize, &hostConn->ibBuffMr)); MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->localSignalEpochId, - sizeof(struct mscclppDevConnSignalEpochId), &conn->ibSignalEpochIdMr)); + sizeof(struct mscclppDevConnSignalEpochId), &hostConn->ibSignalEpochIdMr)); connInfo->infoQp = conn->ibQp->info; - connInfo->infoBuffMr = conn->ibBuffMr->info; - connInfo->infoSignalEpochIdMr = conn->ibSignalEpochIdMr->info; + connInfo->infoBuffMr = hostConn->ibBuffMr->info; + connInfo->infoSignalEpochIdMr = hostConn->ibSignalEpochIdMr->info; return mscclppSuccess; } @@ -560,8 +619,9 @@ mscclppResult_t mscclppIbConnectionSetupEnd(struct connInfo* connInfo /*input*/, WARN("Failed to transition QP to RTS"); return mscclppInvalidUsage; } - conn->ibBuffMrInfo = connInfo->infoBuffMr; - conn->ibSignalEpochIdMrInfo = connInfo->infoSignalEpochIdMr; + struct mscclppHostIBConn* hostConn = (struct mscclppHostIBConn*)conn->hostConn; + hostConn->ibBuffMrInfo = connInfo->infoBuffMr; + hostConn->ibSignalEpochIdMrInfo = connInfo->infoSignalEpochIdMr; return mscclppSuccess; } @@ -658,7 +718,8 @@ MSCCLPP_API mscclppResult_t mscclppRegisteredBufferWrite(mscclppComm_t comm, msc void* dstBuff = regMem->p2p[i].remoteBuff; CUDACHECK(cudaMemcpyAsync(dstBuff, srcBuff, size, cudaMemcpyDeviceToDevice, (cudaStream_t)stream)); } else { - conn->ibQp->stageSend(conn->ibBuffMr, &conn->ibBuffMrInfo, (uint32_t)size, + struct mscclppHostIBConn* hostConn = (struct mscclppHostIBConn*)conn->hostConn; + conn->ibQp->stageSend(hostConn->ibBuffMr, &hostConn->ibBuffMrInfo, (uint32_t)size, /*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset, /*signaled=*/false); @@ -678,6 +739,7 @@ MSCCLPP_API mscclppResult_t mscclppRegisteredBufferWrite(mscclppComm_t comm, msc MSCCLPP_API mscclppResult_t mscclppProxyLaunch(mscclppComm_t comm) { + npkitInitReqIds(comm); MSCCLPPCHECK(mscclppProxyCreate(comm)); return mscclppSuccess; } diff --git a/src/proxy.cc b/src/proxy.cc index d7e291e2..044316d7 100644 --- a/src/proxy.cc +++ b/src/proxy.cc @@ -10,8 +10,6 @@ #include #include -#include "npkit/npkit.h" - #define MSCCLPP_PROXY_RUN_STATE_CHECK_PERIOD 100 #define PROXYCUDACHECK(cmd) \ @@ -40,53 +38,6 @@ struct proxyArgs struct mscclppProxyState* proxyState; }; -#if defined(ENABLE_NPKIT) - -static void npkitInitReqIds(struct mscclppComm* comm) -{ - for (int i = 0; i < comm->nConns; i++) { - struct mscclppConn* conn = &comm->conns[i]; - conn->npkitUsedReqIds.resize(0); - conn->npkitFreeReqIds.resize(MSCCLPP_IB_MAX_SENDS); - for (uint64_t j = 0; j < MSCCLPP_IB_MAX_SENDS; j++) { - conn->npkitFreeReqIds[j] = MSCCLPP_IB_MAX_SENDS - j - 1; - } - } -} - -static void npkitCollectEntryEvent(struct mscclppConn* conn, uint8_t type, uint32_t size, int channelId) -{ - uint64_t reqId = 0; - if (conn->npkitFreeReqIds.size() == 0) { - reqId = conn->npkitUsedReqIds.size(); - } else { - reqId = conn->npkitFreeReqIds.back(); - conn->npkitFreeReqIds.pop_back(); - } - conn->npkitUsedReqIds.push_back(reqId); - NpKit::CollectCpuEvent(type, size, (uint32_t)reqId, NpKit::GetCpuTimestamp(), channelId); -} - -static void npkitCollectExitEvents(struct mscclppConn* conn, uint8_t type, int channelId) -{ - while (conn->npkitUsedReqIds.size()) { - uint64_t reqId = conn->npkitUsedReqIds.back(); - NpKit::CollectCpuEvent(type, 0, (uint32_t)reqId, NpKit::GetCpuTimestamp(), channelId); - conn->npkitFreeReqIds.push_back(reqId); - conn->npkitUsedReqIds.pop_back(); - } -} - -#else - -#define npkitInitReqIds(comm) - -#define npkitCollectEntryEvent(conn, type, size, channelId) - -#define npkitCollectExitEvents(conn, type, channelId) - -#endif - mscclppResult_t mscclppProxyFifo::create() { MSCCLPPCHECK(mscclppCudaCalloc(&this->fifoHead, 1)); @@ -150,29 +101,20 @@ mscclppResult_t mscclppProxyFifo::flushTail(bool sync) return mscclppSuccess; } -void processTrigger(const mscclppTrigger trigger, mscclppConn* conn, mscclppProxyState* proxyState){ - mscclppIbContext* ibCtx = proxyState->ibContext; - bool isP2pProxy = (ibCtx == nullptr); - +static void processTrigger(const mscclppTrigger trigger, mscclppConn* conn) +{ // Iterate over what send is needed if (trigger.fields.type & mscclppData) { conn->hostConn->put(trigger.fields.dstDataOffset, trigger.fields.srcDataOffset, trigger.fields.dataSize); - - npkitCollectEntryEvent(conn, isP2pProxy ? NPKIT_EVENT_DMA_SEND_DATA_ENTRY : NPKIT_EVENT_IB_SEND_DATA_ENTRY, - (uint32_t)trigger.fields.dataSize, trigger.fields.connId); } if (trigger.fields.type & mscclppFlag) { conn->hostConn->signal(); - - npkitCollectEntryEvent(conn, isP2pProxy ? NPKIT_EVENT_P2P_SEND_FLAG_ENTRY : NPKIT_EVENT_IB_SEND_FLAG_ENTRY, - (uint32_t)sizeof(uint64_t), trigger.fields.connId); } // Wait for completion if (trigger.fields.type & mscclppSync) { conn->hostConn->flush(); - npkitCollectExitEvents(conn, isP2pProxy? NPKIT_EVENT_DMA_SEND_EXIT : NPKIT_EVENT_IB_SEND_EXIT, trigger.fields.connId); } } @@ -191,8 +133,6 @@ void* mscclppProxyService(void* _args) volatile mscclppProxyRunState_t* run = &proxyState->run; mscclppTrigger trigger; - npkitInitReqIds(comm); - int runCnt = MSCCLPP_PROXY_RUN_STATE_CHECK_PERIOD; uint64_t flushCnt = 0; for (;;) { @@ -209,7 +149,7 @@ void* mscclppProxyService(void* _args) } mscclppConn* conn = &comm->conns[trigger.fields.connId]; - processTrigger(trigger, conn, proxyState); + processTrigger(trigger, conn); // Send completion: reset only the high 64 bits PROXYMSCCLPPCHECK(fifo->pop());