flush mechanism

This commit is contained in:
Saeed Maleki
2023-03-29 17:31:20 +00:00
parent 7a0962e4be
commit d97bee6973
2 changed files with 13 additions and 8 deletions

View File

@@ -89,17 +89,13 @@ struct mscclppDevConn
__forceinline__ __device__ void signal()
{
epochIncrement();
uint64_t curFifoHead = fifo.push(mscclppFlag | mscclppSync, 0, 0, 1);
while (*(volatile uint64_t*)fifo.triggerFifoTail <= curFifoHead)
;
fifo.push(mscclppFlag, 0, 0, 1);
}
__forceinline__ __device__ void putWithSignal(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize)
{
epochIncrement();
uint64_t curFifoHead = fifo.push(mscclppData | mscclppFlag | mscclppSync, dstDataOffset, srcDataOffset, dataSize);
while (*(volatile uint64_t*)fifo.triggerFifoTail <= curFifoHead)
;
fifo.push(mscclppData | mscclppFlag, dstDataOffset, srcDataOffset, dataSize);
}
__forceinline__ __device__ void putWithSignal(uint64_t dataOffset, uint64_t dataSize)
@@ -107,6 +103,14 @@ struct mscclppDevConn
putWithSignal(dataOffset, dataOffset, dataSize);
}
__forceinline__ __device__ void flush()
{
epochIncrement();
uint64_t curFifoHead = fifo.push(mscclppSync, 0, 0, 1);
while (*(volatile uint64_t*)fifo.triggerFifoTail <= curFifoHead)
;
}
__forceinline__ __device__ void wait()
{
(*recvEpochId) += 1;

View File

@@ -53,11 +53,11 @@ __device__ void allgather0(mscclppDevConn_t devConn, int rank, int world_size, i
// this thread's role is a sender role
// put your data asynchronously
devConn.put(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
devConn.putWithSignal(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
// make sure everyone is put their data before some thread randomly blocks everyone else in signal
__syncthreads();
// push with flag and sync to make sure the data is received
devConn.signal();
devConn.flush();
// this thread's role is a receiver role. wait on the semaphore to make sure the data is ready
devConn.wait();
@@ -77,6 +77,7 @@ __device__ void allgather1(mscclppDevConn_t devConn, int rank, int world_size, i
continue;
// put your data to GPU (rank+i) % world_size and signal all in one call
devConn.putWithSignal(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
devConn.flush();
}
// all connections wait for the signal from the sender
devConn.wait();