Thread-safe trigger

This commit is contained in:
Changho Hwang
2023-03-17 09:46:23 +00:00
parent 2061ea91f7
commit 67dbbd1692
4 changed files with 29 additions and 7 deletions

View File

@@ -78,7 +78,23 @@ struct mscclppDevConn {
#ifdef __CUDACC__
// Atomically advances the trigger FIFO head and returns a pointer to the
// slot it referenced before the increment. atomicInc wraps the head back to
// 0 once it reaches MSCCLPP_PROXY_FIFO_SIZE - 1, so the index is always in
// range. Safe for concurrent callers; note it does NOT check whether the
// slot is free — callers must go through acquireTrigger() for that.
__forceinline__ __device__ mscclppTrigger *getTrigger() {
  unsigned int curFifoHead = atomicInc(this->triggerFifoHead, MSCCLPP_PROXY_FIFO_SIZE - 1);
  return &this->triggerFifo[curFifoHead];
}
// Blocks until a FIFO slot is available, then returns a pointer to the next
// trigger slot. triggerFifoCounter counts in-flight (acquired but not yet
// released) triggers; pair every call with releaseTrigger().
__forceinline__ __device__ mscclppTrigger *acquireTrigger() {
unsigned int *cnt = this->triggerFifoCounter;
// Optimistically claim a slot by bumping the in-flight counter.
unsigned int old = atomicAdd(cnt, 1);
while (old >= MSCCLPP_PROXY_FIFO_SIZE) {
// FIFO was full: undo our claim so spinning threads see the true count.
atomicSub(cnt, 1);
// Spin until releaseTrigger() frees a slot; the volatile cast forces a
// fresh read of the counter each iteration.
while (*(volatile unsigned int *)cnt >= MSCCLPP_PROXY_FIFO_SIZE) {}
// Re-claim and re-check: another thread may have won the freed slot.
old = atomicAdd(cnt, 1);
}
// Up to MSCCLPP_PROXY_FIFO_SIZE threads can enter here at the same time
return getTrigger();
}
// Releases one previously acquired trigger slot by decrementing the
// in-flight counter, allowing a thread spinning in acquireTrigger() to
// proceed. Must be called exactly once per successful acquireTrigger().
__forceinline__ __device__ void releaseTrigger() {
atomicSub(this->triggerFifoCounter, 1);
}
__forceinline__ __device__ void setTrigger(mscclppTrigger *trig, uint64_t type, uint64_t dataOffset, uint64_t dataSize) {
@@ -104,7 +120,8 @@ struct mscclppDevConn {
uint64_t* remoteFlag;
unsigned int* triggerFifoHead; // indicates the tail of the fifo. only accessible by the gpu. for parallel, access use atomic
mscclppTrigger* trigger;
mscclppTrigger* triggerFifo;
unsigned int* triggerFifoCounter;
uint64_t* proxyFlag;
int connId;
};

View File

@@ -21,6 +21,7 @@ struct mscclppProxyState {
// cpuTriggerFifoTail indicates where CPU needs to read the head of the fifo.
unsigned int cpuTriggerFifoTail;
unsigned int *gpuTriggerFifoHead;
unsigned int *gpuTriggerFifoCounter;
void *cpuTriggerFifoGdrDesc;
// NULL for the P2P proxy.
struct mscclppIbContext *ibContext;

View File

@@ -211,6 +211,7 @@ mscclppResult_t mscclppConnect(mscclppComm_t comm, mscclppDevConn* devConnOut, i
MSCCLPPCHECK(mscclppGdrCudaCalloc(&proxyState->cpuTriggerFifo, &proxyState->gpuTriggerFifo,
MSCCLPP_PROXY_FIFO_SIZE, &proxyState->cpuTriggerFifoGdrDesc));
MSCCLPPCHECK(mscclppCudaCalloc(&proxyState->gpuTriggerFifoHead, 1));
MSCCLPPCHECK(mscclppCudaCalloc(&proxyState->gpuTriggerFifoCounter, 1));
proxyState->ibContext = conn->ibCtx;
comm->proxyState[i] = proxyState;
break;
@@ -231,8 +232,9 @@ mscclppResult_t mscclppConnect(mscclppComm_t comm, mscclppDevConn* devConnOut, i
conn->devConn->localFlag = localFlag;
conn->devConn->tag = tag;
conn->devConn->connId = comm->nConns;
conn->devConn->trigger = proxyState->gpuTriggerFifo;
conn->devConn->triggerFifo = proxyState->gpuTriggerFifo;
conn->devConn->triggerFifoHead = proxyState->gpuTriggerFifoHead;
conn->devConn->triggerFifoCounter = proxyState->gpuTriggerFifoCounter;
comm->nConns++;
return mscclppSuccess;

View File

@@ -54,7 +54,6 @@ __global__ void kernel(int rank, int world_size, int nelemsPerGPU)
volatile uint64_t *remoteFlag = devConn.remoteFlag;
#endif
volatile uint64_t *proxyFlag = devConn.proxyFlag;
mscclppTrigger *trig = devConn.getTrigger();
uint64_t baseFlag = *localFlag;
@@ -65,18 +64,21 @@ __global__ void kernel(int rank, int world_size, int nelemsPerGPU)
*localFlag = baseFlag + 1;
}
// Thread-safely obtain the head trigger
mscclppTrigger *trig = devConn.acquireTrigger();
// Each warp receives data from different ranks
#if (USE_DMA_FOR_P2P == 1)
// Prevent overwriting trigger
devConn.waitTrigger(trig);
// Trigger sending data and flag
devConn.setTrigger(trig, mscclppFlag | mscclppData | mscclppSync, rank * nelemsPerGPU * sizeof(int), nelemsPerGPU*sizeof(int));
// Wait until the proxy have sent my data and flag
devConn.waitTrigger(trig);
// Inform other threads that the tail trigger just became idle
devConn.releaseTrigger();
// Wait for receiving data from remote rank
while (*proxyFlag == baseFlag) {}