mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-14 02:03:03 +00:00
Thread-safe trigger
This commit is contained in:
@@ -78,7 +78,23 @@ struct mscclppDevConn {
|
||||
#ifdef __CUDACC__
|
||||
__forceinline__ __device__ mscclppTrigger *getTrigger() {
|
||||
unsigned int curFifoHead = atomicInc(this->triggerFifoHead, MSCCLPP_PROXY_FIFO_SIZE - 1);
|
||||
return &this->trigger[curFifoHead];
|
||||
return &this->triggerFifo[curFifoHead];
|
||||
}
|
||||
|
||||
__forceinline__ __device__ mscclppTrigger *acquireTrigger() {
|
||||
unsigned int *cnt = this->triggerFifoCounter;
|
||||
unsigned int old = atomicAdd(cnt, 1);
|
||||
while (old >= MSCCLPP_PROXY_FIFO_SIZE) {
|
||||
atomicSub(cnt, 1);
|
||||
while (*(volatile unsigned int *)cnt >= MSCCLPP_PROXY_FIFO_SIZE) {}
|
||||
old = atomicAdd(cnt, 1);
|
||||
}
|
||||
// Up to MSCCLPP_PROXY_FIFO_SIZE threads can enter here at the same time
|
||||
return getTrigger();
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void releaseTrigger() {
|
||||
atomicSub(this->triggerFifoCounter, 1);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void setTrigger(mscclppTrigger *trig, uint64_t type, uint64_t dataOffset, uint64_t dataSize) {
|
||||
@@ -104,7 +120,8 @@ struct mscclppDevConn {
|
||||
uint64_t* remoteFlag;
|
||||
|
||||
unsigned int* triggerFifoHead; // indicates the tail of the fifo. only accessible by the gpu. for parallel, access use atomic
|
||||
mscclppTrigger* trigger;
|
||||
mscclppTrigger* triggerFifo;
|
||||
unsigned int* triggerFifoCounter;
|
||||
uint64_t* proxyFlag;
|
||||
int connId;
|
||||
};
|
||||
|
||||
@@ -21,6 +21,7 @@ struct mscclppProxyState {
|
||||
// cpuTriggerFifoTail indicates where CPU needs to read the head of the fifo.
|
||||
unsigned int cpuTriggerFifoTail;
|
||||
unsigned int *gpuTriggerFifoHead;
|
||||
unsigned int *gpuTriggerFifoCounter;
|
||||
void *cpuTriggerFifoGdrDesc;
|
||||
// NULL for the P2P proxy.
|
||||
struct mscclppIbContext *ibContext;
|
||||
|
||||
@@ -211,6 +211,7 @@ mscclppResult_t mscclppConnect(mscclppComm_t comm, mscclppDevConn* devConnOut, i
|
||||
MSCCLPPCHECK(mscclppGdrCudaCalloc(&proxyState->cpuTriggerFifo, &proxyState->gpuTriggerFifo,
|
||||
MSCCLPP_PROXY_FIFO_SIZE, &proxyState->cpuTriggerFifoGdrDesc));
|
||||
MSCCLPPCHECK(mscclppCudaCalloc(&proxyState->gpuTriggerFifoHead, 1));
|
||||
MSCCLPPCHECK(mscclppCudaCalloc(&proxyState->gpuTriggerFifoCounter, 1));
|
||||
proxyState->ibContext = conn->ibCtx;
|
||||
comm->proxyState[i] = proxyState;
|
||||
break;
|
||||
@@ -231,8 +232,9 @@ mscclppResult_t mscclppConnect(mscclppComm_t comm, mscclppDevConn* devConnOut, i
|
||||
conn->devConn->localFlag = localFlag;
|
||||
conn->devConn->tag = tag;
|
||||
conn->devConn->connId = comm->nConns;
|
||||
conn->devConn->trigger = proxyState->gpuTriggerFifo;
|
||||
conn->devConn->triggerFifo = proxyState->gpuTriggerFifo;
|
||||
conn->devConn->triggerFifoHead = proxyState->gpuTriggerFifoHead;
|
||||
conn->devConn->triggerFifoCounter = proxyState->gpuTriggerFifoCounter;
|
||||
|
||||
comm->nConns++;
|
||||
return mscclppSuccess;
|
||||
|
||||
@@ -54,7 +54,6 @@ __global__ void kernel(int rank, int world_size, int nelemsPerGPU)
|
||||
volatile uint64_t *remoteFlag = devConn.remoteFlag;
|
||||
#endif
|
||||
volatile uint64_t *proxyFlag = devConn.proxyFlag;
|
||||
mscclppTrigger *trig = devConn.getTrigger();
|
||||
|
||||
uint64_t baseFlag = *localFlag;
|
||||
|
||||
@@ -65,18 +64,21 @@ __global__ void kernel(int rank, int world_size, int nelemsPerGPU)
|
||||
*localFlag = baseFlag + 1;
|
||||
}
|
||||
|
||||
// Thread-safely obtain the head trigger
|
||||
mscclppTrigger *trig = devConn.acquireTrigger();
|
||||
|
||||
// Each warp receives data from different ranks
|
||||
#if (USE_DMA_FOR_P2P == 1)
|
||||
|
||||
// Prevent overwriting trigger
|
||||
devConn.waitTrigger(trig);
|
||||
|
||||
// Trigger sending data and flag
|
||||
devConn.setTrigger(trig, mscclppFlag | mscclppData | mscclppSync, rank * nelemsPerGPU * sizeof(int), nelemsPerGPU*sizeof(int));
|
||||
|
||||
// Wait until the proxy have sent my data and flag
|
||||
devConn.waitTrigger(trig);
|
||||
|
||||
// Inform other threads that the tail trigger just became idle
|
||||
devConn.releaseTrigger();
|
||||
|
||||
// Wait for receiving data from remote rank
|
||||
while (*proxyFlag == baseFlag) {}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user