diff --git a/src/include/mscclpp.h b/src/include/mscclpp.h index fe21b0b0..5c76b7af 100644 --- a/src/include/mscclpp.h +++ b/src/include/mscclpp.h @@ -78,7 +78,23 @@ struct mscclppDevConn { #ifdef __CUDACC__ __forceinline__ __device__ mscclppTrigger *getTrigger() { unsigned int curFifoHead = atomicInc(this->triggerFifoHead, MSCCLPP_PROXY_FIFO_SIZE - 1); - return &this->trigger[curFifoHead]; + return &this->triggerFifo[curFifoHead]; + } + + __forceinline__ __device__ mscclppTrigger *acquireTrigger() { + unsigned int *cnt = this->triggerFifoCounter; + unsigned int old = atomicAdd(cnt, 1); + while (old >= MSCCLPP_PROXY_FIFO_SIZE) { + atomicSub(cnt, 1); + while (*(volatile unsigned int *)cnt >= MSCCLPP_PROXY_FIFO_SIZE) {} + old = atomicAdd(cnt, 1); + } + // triggerFifoCounter gates entry: a failed attempt undoes its increment and spins until the count reads below MSCCLPP_PROXY_FIFO_SIZE before retrying, so up to MSCCLPP_PROXY_FIFO_SIZE threads can enter here at the same time; each one keeps its count until it calls releaseTrigger() + return getTrigger(); + } + + __forceinline__ __device__ void releaseTrigger() { + atomicSub(this->triggerFifoCounter, 1); } __forceinline__ __device__ void setTrigger(mscclppTrigger *trig, uint64_t type, uint64_t dataOffset, uint64_t dataSize) { @@ -104,7 +120,8 @@ struct mscclppDevConn { uint64_t* remoteFlag; unsigned int* triggerFifoHead; // indicates the tail of the fifo. only accessible by the gpu. for parallel, access use atomic - mscclppTrigger* trigger; + mscclppTrigger* triggerFifo; + unsigned int* triggerFifoCounter; uint64_t* proxyFlag; int connId; }; diff --git a/src/include/proxy.h b/src/include/proxy.h index 61f9ea24..7af16279 100644 --- a/src/include/proxy.h +++ b/src/include/proxy.h @@ -21,6 +21,7 @@ struct mscclppProxyState { // cpuTriggerFifoTail indicates where CPU needs to read the head of the fifo. unsigned int cpuTriggerFifoTail; unsigned int *gpuTriggerFifoHead; + unsigned int *gpuTriggerFifoCounter; void *cpuTriggerFifoGdrDesc; // NULL for the P2P proxy. 
struct mscclppIbContext *ibContext; diff --git a/src/init.cc b/src/init.cc index 5cb2cdcd..3b718f88 100644 --- a/src/init.cc +++ b/src/init.cc @@ -211,6 +211,7 @@ mscclppResult_t mscclppConnect(mscclppComm_t comm, mscclppDevConn* devConnOut, i MSCCLPPCHECK(mscclppGdrCudaCalloc(&proxyState->cpuTriggerFifo, &proxyState->gpuTriggerFifo, MSCCLPP_PROXY_FIFO_SIZE, &proxyState->cpuTriggerFifoGdrDesc)); MSCCLPPCHECK(mscclppCudaCalloc(&proxyState->gpuTriggerFifoHead, 1)); + MSCCLPPCHECK(mscclppCudaCalloc(&proxyState->gpuTriggerFifoCounter, 1)); proxyState->ibContext = conn->ibCtx; comm->proxyState[i] = proxyState; break; @@ -231,8 +232,9 @@ mscclppResult_t mscclppConnect(mscclppComm_t comm, mscclppDevConn* devConnOut, i conn->devConn->localFlag = localFlag; conn->devConn->tag = tag; conn->devConn->connId = comm->nConns; - conn->devConn->trigger = proxyState->gpuTriggerFifo; + conn->devConn->triggerFifo = proxyState->gpuTriggerFifo; conn->devConn->triggerFifoHead = proxyState->gpuTriggerFifoHead; + conn->devConn->triggerFifoCounter = proxyState->gpuTriggerFifoCounter; comm->nConns++; return mscclppSuccess; diff --git a/tests/allgather_test.cu b/tests/allgather_test.cu index a6bf5ea4..57bb57b3 100644 --- a/tests/allgather_test.cu +++ b/tests/allgather_test.cu @@ -54,7 +54,6 @@ __global__ void kernel(int rank, int world_size, int nelemsPerGPU) volatile uint64_t *remoteFlag = devConn.remoteFlag; #endif volatile uint64_t *proxyFlag = devConn.proxyFlag; - mscclppTrigger *trig = devConn.getTrigger(); uint64_t baseFlag = *localFlag; @@ -65,18 +64,21 @@ __global__ void kernel(int rank, int world_size, int nelemsPerGPU) *localFlag = baseFlag + 1; } + // Thread-safely obtain the head trigger + mscclppTrigger *trig = devConn.acquireTrigger(); + // Each warp receives data from different ranks #if (USE_DMA_FOR_P2P == 1) - // Prevent overwriting trigger - devConn.waitTrigger(trig); - // Trigger sending data and flag devConn.setTrigger(trig, mscclppFlag | mscclppData | mscclppSync, rank * 
nelemsPerGPU * sizeof(int), nelemsPerGPU*sizeof(int)); // Wait until the proxy have sent my data and flag devConn.waitTrigger(trig); + // Inform other threads that the tail trigger just became idle + devConn.releaseTrigger(); + // Wait for receiving data from remote rank while (*proxyFlag == baseFlag) {}