diff --git a/src/include/comm.h b/src/include/comm.h index f1e3ed47..56e24e2d 100644 --- a/src/include/comm.h +++ b/src/include/comm.h @@ -36,10 +36,8 @@ struct mscclppConn struct mscclppIbQp* ibQp; struct mscclppIbMr* ibBuffMr; struct mscclppIbMr* ibSignalEpochIdMr; - struct mscclppIbMr* ibProxySignalEpochIdMr; struct mscclppIbMrInfo ibBuffMrInfo; struct mscclppIbMrInfo ibSignalEpochIdMrInfo; - struct mscclppIbMrInfo ibProxySignalEpochIdMrInfo; #if defined(ENABLE_NPKIT) std::vector npkitUsedReqIds; std::vector npkitFreeReqIds; diff --git a/src/include/mscclpp.h b/src/include/mscclpp.h index f7261410..b620153c 100644 --- a/src/include/mscclpp.h +++ b/src/include/mscclpp.h @@ -20,6 +20,16 @@ extern "C" { #endif +struct alignas(16) mscclppDevConnSignalEpochId +{ + // every signal(), increments this and either: + // 1) proxy thread pushes it to the remote peer's localSignalEpochId->proxy + // 2) gpu thread directly writes it to remoteSignalEpochId->device + uint64_t device; + // signal() function triggers the cpu proxy thread to write to it + uint64_t proxy; +}; + /*************************************************************************************************************** * A mscclppDevConn provides a zero-copy connection between two GPUs connected via P2P NVLink or InfiniBand. 
* The communication API is one-sided meaning that for every single data transfer, only one side @@ -140,14 +150,13 @@ struct mscclppDevConn __forceinline__ __device__ void wait() { (*waitEpochId) += 1; - // printf("%llu %llu %llu\n", *(volatile uint64_t*)proxySignalEpochId, (*waitEpochId), *(volatile uint64_t*)signalEpochId); - while (*(volatile uint64_t*)proxySignalEpochId < (*waitEpochId)) + while (*(volatile uint64_t*)&(localSignalEpochId->proxy) < (*waitEpochId)) ; } __forceinline__ __device__ void epochIncrement() { - *(volatile uint64_t*)signalEpochId += 1; + *(volatile uint64_t*)&(localSignalEpochId->device) += 1; } #endif @@ -156,22 +165,19 @@ struct mscclppDevConn // my local buffer void* localBuff; - // every signal(), increaments this and either: - // 1) proxy thread pushes it to the remote peer's proxySignalEpochId - // 2) gpu thread directly writes it to remoteSignalEpochId - uint64_t* signalEpochId; + + struct mscclppDevConnSignalEpochId* localSignalEpochId; + // used by the signal() function directly from gpu + struct mscclppDevConnSignalEpochId* remoteSignalEpochId; + // every wait(), increaments this and then the gpu waits for either: - // 1) proxySignalEpochId to be >= this in case of a proxy thread - // 2) remoteSignalEpochId to be >= this in case of a gpu thread + // 1) localSignalEpochId->proxy to be >= this in case of a proxy thread + // 2) remoteSignalEpochId->device to be >= this in case of a gpu thread uint64_t* waitEpochId; // my remote peer's buffer. only non-NULL with gpu's direct access // gpu can directly write into it void* remoteBuff; - // used by the signal() function directly from gpu - uint64_t* remoteSignalEpochId; - // signal() function triggers the cpu proxy thread to write to it - uint64_t* proxySignalEpochId; // this is a concurrent fifo which is multiple threads from the device // can produce for and the sole proxy thread consumes it. 
diff --git a/src/init.cc b/src/init.cc index 0cff496a..08302f6e 100644 --- a/src/init.cc +++ b/src/init.cc @@ -181,11 +181,6 @@ mscclppResult_t mscclppCommDestroy(mscclppComm_t comm) if (comm == NULL) return mscclppSuccess; - for (int i = 0; i < comm->nConns; ++i) { - struct mscclppConn* conn = &comm->conns[i]; - MSCCLPPCHECK(mscclppCudaFree(conn->devConn->proxySignalEpochId)); - } - for (int i = 0; i < MSCCLPP_PROXY_MAX_NUM; ++i) { struct mscclppProxyState* proxyState = comm->proxyState[i]; if (proxyState) { @@ -216,7 +211,7 @@ mscclppResult_t mscclppCommDestroy(mscclppComm_t comm) for (int i = 0; i < comm->nConns; i++) { struct mscclppConn* conn = &comm->conns[i]; if (conn) { - MSCCLPPCHECK(mscclppCudaFree(conn->devConn->signalEpochId)); + MSCCLPPCHECK(mscclppCudaFree(conn->devConn->localSignalEpochId)); MSCCLPPCHECK(mscclppCudaFree(conn->devConn->waitEpochId)); } } @@ -419,7 +414,7 @@ mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, int tag, void conn->devConn = devConn; conn->devConn->localBuff = localBuff; - MSCCLPPCHECK(mscclppCudaCalloc(&conn->devConn->signalEpochId, 1)); + MSCCLPPCHECK(mscclppCudaCalloc(&conn->devConn->localSignalEpochId, 1)); MSCCLPPCHECK(mscclppCudaCalloc(&conn->devConn->waitEpochId, 1)); conn->devConn->remoteRank = remoteRank; conn->devConn->tag = tag; @@ -444,11 +439,9 @@ struct connInfo { cudaIpcMemHandle_t handleBuff; cudaIpcMemHandle_t handleSignalEpochId; - cudaIpcMemHandle_t handleProxySignalEpochId; mscclppIbQpInfo infoQp; mscclppIbMrInfo infoBuffMr; mscclppIbMrInfo infoSignalEpochIdMr; - mscclppIbMrInfo infoProxySignalEpochIdMr; }; mscclppResult_t mscclppP2pConnectionSetupStart(struct connInfo* connInfo /*output*/, struct mscclppConn* conn /*input*/) @@ -458,10 +451,8 @@ mscclppResult_t mscclppP2pConnectionSetupStart(struct connInfo* connInfo /*outpu return mscclppInternalError; } struct mscclppDevConn* devConn = conn->devConn; - MSCCLPPCHECK(mscclppCudaCalloc(&devConn->proxySignalEpochId, 1)); - 
CUDACHECK(cudaIpcGetMemHandle(&connInfo->handleProxySignalEpochId, devConn->proxySignalEpochId)); CUDACHECK(cudaIpcGetMemHandle(&connInfo->handleBuff, devConn->localBuff)); - CUDACHECK(cudaIpcGetMemHandle(&connInfo->handleSignalEpochId, devConn->signalEpochId)); + CUDACHECK(cudaIpcGetMemHandle(&connInfo->handleSignalEpochId, devConn->localSignalEpochId)); return mscclppSuccess; } @@ -475,8 +466,7 @@ mscclppResult_t mscclppP2pConnectionSetupEnd(struct connInfo* connInfo /*input*/ cudaIpcOpenMemHandle((void**)&conn->devConn->remoteBuff, connInfo->handleBuff, cudaIpcMemLazyEnablePeerAccess)); CUDACHECK( cudaIpcOpenMemHandle((void**)&conn->devConn->remoteSignalEpochId, connInfo->handleSignalEpochId, cudaIpcMemLazyEnablePeerAccess)); - CUDACHECK( - cudaIpcOpenMemHandle((void**)&conn->remoteProxyFlag, connInfo->handleProxySignalEpochId, cudaIpcMemLazyEnablePeerAccess)); + conn->remoteProxyFlag = &(conn->devConn->remoteSignalEpochId->proxy); return mscclppSuccess; } @@ -489,20 +479,16 @@ mscclppResult_t mscclppIbConnectionSetupStart(struct connInfo* connInfo /*output struct mscclppDevConn* devConn = conn->devConn; devConn->remoteBuff = NULL; devConn->remoteSignalEpochId = NULL; - MSCCLPPCHECK(mscclppCudaCalloc(&devConn->proxySignalEpochId, 1)); struct mscclppIbContext* ibCtx = conn->ibCtx; if (conn->ibQp == NULL) { MSCCLPPCHECK(mscclppIbContextCreateQp(ibCtx, &conn->ibQp)); } - // TODO(chhwang): can we register only one MR for the following three? 
MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->localBuff, conn->buffSize, &conn->ibBuffMr)); - MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->signalEpochId, sizeof(uint64_t), &conn->ibSignalEpochIdMr)); - MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->proxySignalEpochId, sizeof(uint64_t), &conn->ibProxySignalEpochIdMr)); + MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, devConn->localSignalEpochId, sizeof(struct mscclppDevConnSignalEpochId), &conn->ibSignalEpochIdMr)); connInfo->infoQp = conn->ibQp->info; connInfo->infoBuffMr = conn->ibBuffMr->info; connInfo->infoSignalEpochIdMr = conn->ibSignalEpochIdMr->info; - connInfo->infoProxySignalEpochIdMr = conn->ibProxySignalEpochIdMr->info; return mscclppSuccess; } @@ -522,7 +508,6 @@ mscclppResult_t mscclppIbConnectionSetupEnd(struct connInfo* connInfo /*input*/, } conn->ibBuffMrInfo = connInfo->infoBuffMr; conn->ibSignalEpochIdMrInfo = connInfo->infoSignalEpochIdMr; - conn->ibProxySignalEpochIdMrInfo = connInfo->infoProxySignalEpochIdMr; return mscclppSuccess; } diff --git a/src/proxy.cc b/src/proxy.cc index cc0ae870..6545c855 100644 --- a/src/proxy.cc +++ b/src/proxy.cc @@ -162,13 +162,13 @@ void* mscclppProxyService(void* _args) } if (trigger.fields.type & mscclppFlag) { if (isP2pProxy) { - PROXYCUDACHECK(cudaMemcpyAsync(conn->remoteProxyFlag, conn->devConn->signalEpochId, sizeof(uint64_t), + PROXYCUDACHECK(cudaMemcpyAsync(conn->remoteProxyFlag, &(conn->devConn->localSignalEpochId->device), sizeof(uint64_t), cudaMemcpyDeviceToDevice, p2pStream)); npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_ENTRY, (uint32_t)sizeof(uint64_t), trigger.fields.connId); } else { - // My local flag is copied to the peer's proxy flag - conn->ibQp->stageSend(conn->ibSignalEpochIdMr, &conn->ibProxySignalEpochIdMrInfo, sizeof(uint64_t), - /*wrId=*/0, /*srcOffset=*/0, /*dstOffset=*/0, /*signaled=*/true); + // My local device flag is copied to the remote's proxy flag + conn->ibQp->stageSend(conn->ibSignalEpochIdMr, 
&conn->ibSignalEpochIdMrInfo, sizeof(uint64_t), + /*wrId=*/0, /*srcOffset=*/0, /*dstOffset=*/sizeof(uint64_t), /*signaled=*/true); if ((ret = conn->ibQp->postSend()) != 0) { WARN("flag postSend failed: errno %d", ret); }