mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-13 01:36:10 +00:00
flush mechanism
This commit is contained in:
@@ -89,17 +89,13 @@ struct mscclppDevConn
|
||||
__forceinline__ __device__ void signal()
|
||||
{
|
||||
epochIncrement();
|
||||
uint64_t curFifoHead = fifo.push(mscclppFlag | mscclppSync, 0, 0, 1);
|
||||
while (*(volatile uint64_t*)fifo.triggerFifoTail <= curFifoHead)
|
||||
;
|
||||
fifo.push(mscclppFlag, 0, 0, 1);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignal(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize)
|
||||
{
|
||||
epochIncrement();
|
||||
uint64_t curFifoHead = fifo.push(mscclppData | mscclppFlag | mscclppSync, dstDataOffset, srcDataOffset, dataSize);
|
||||
while (*(volatile uint64_t*)fifo.triggerFifoTail <= curFifoHead)
|
||||
;
|
||||
fifo.push(mscclppData | mscclppFlag, dstDataOffset, srcDataOffset, dataSize);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void putWithSignal(uint64_t dataOffset, uint64_t dataSize)
|
||||
@@ -107,6 +103,14 @@ struct mscclppDevConn
|
||||
putWithSignal(dataOffset, dataOffset, dataSize);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void flush()
|
||||
{
|
||||
epochIncrement();
|
||||
uint64_t curFifoHead = fifo.push(mscclppSync, 0, 0, 1);
|
||||
while (*(volatile uint64_t*)fifo.triggerFifoTail <= curFifoHead)
|
||||
;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void wait()
|
||||
{
|
||||
(*recvEpochId) += 1;
|
||||
|
||||
@@ -53,11 +53,11 @@ __device__ void allgather0(mscclppDevConn_t devConn, int rank, int world_size, i
|
||||
|
||||
// this thread's role is a sender role
|
||||
// put your data asynchronously
|
||||
devConn.put(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
|
||||
devConn.putWithSignal(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
|
||||
// make sure everyone has put their data before some thread randomly blocks everyone else in signal
|
||||
__syncthreads();
|
||||
// push with flag and sync to make sure the data is received
|
||||
devConn.signal();
|
||||
devConn.flush();
|
||||
|
||||
// this thread's role is a receiver role. wait on the semaphore to make sure the data is ready
|
||||
devConn.wait();
|
||||
@@ -77,6 +77,7 @@ __device__ void allgather1(mscclppDevConn_t devConn, int rank, int world_size, i
|
||||
continue;
|
||||
// put your data to GPU (rank+i) % world_size and signal all in one call
|
||||
devConn.putWithSignal(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
|
||||
devConn.flush();
|
||||
}
|
||||
// all connections wait for the signal from the sender
|
||||
devConn.wait();
|
||||
|
||||
Reference in New Issue
Block a user