From d97bee6973587f7f4f9c78e123ec190560757b8b Mon Sep 17 00:00:00 2001
From: Saeed Maleki
Date: Wed, 29 Mar 2023 17:31:20 +0000
Subject: [PATCH] flush mechanism

signal() and putWithSignal() no longer block: they bump the send epoch
and push their trigger, nothing more. Waiting for the proxy to drain the
FIFO moves into a new flush() call, which pushes an mscclppSync trigger
and spins until triggerFifoTail has moved past it.

flush() deliberately does not call epochIncrement(). It pushes no
mscclppFlag, and wait() advances recvEpochId by one per call, so an
extra bump per flush() would let the send epoch run ahead of the
receiver's count.

The allgather tests now call putWithSignal() followed by flush().

---
 src/include/mscclpp.h   | 17 +++++++++++------
 tests/allgather_test.cu |  5 +++--
 2 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/src/include/mscclpp.h b/src/include/mscclpp.h
index 46a1d4a9..e748976d 100644
--- a/src/include/mscclpp.h
+++ b/src/include/mscclpp.h
@@ -89,17 +89,13 @@ struct mscclppDevConn
   __forceinline__ __device__ void signal()
   {
     epochIncrement();
-    uint64_t curFifoHead = fifo.push(mscclppFlag | mscclppSync, 0, 0, 1);
-    while (*(volatile uint64_t*)fifo.triggerFifoTail <= curFifoHead)
-      ;
+    fifo.push(mscclppFlag, 0, 0, 1);
   }
 
   __forceinline__ __device__ void putWithSignal(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize)
   {
     epochIncrement();
-    uint64_t curFifoHead = fifo.push(mscclppData | mscclppFlag | mscclppSync, dstDataOffset, srcDataOffset, dataSize);
-    while (*(volatile uint64_t*)fifo.triggerFifoTail <= curFifoHead)
-      ;
+    fifo.push(mscclppData | mscclppFlag, dstDataOffset, srcDataOffset, dataSize);
   }
 
   __forceinline__ __device__ void putWithSignal(uint64_t dataOffset, uint64_t dataSize)
@@ -107,6 +103,15 @@ struct mscclppDevConn
     putWithSignal(dataOffset, dataOffset, dataSize);
   }
 
+  // Push a sync-only trigger and spin until triggerFifoTail has moved past it.
+  // No epochIncrement() here: only the flag-pushing calls bump the send epoch.
+  __forceinline__ __device__ void flush()
+  {
+    uint64_t curFifoHead = fifo.push(mscclppSync, 0, 0, 1);
+    while (*(volatile uint64_t*)fifo.triggerFifoTail <= curFifoHead)
+      ;
+  }
+
   __forceinline__ __device__ void wait()
   {
     (*recvEpochId) += 1;
diff --git a/tests/allgather_test.cu b/tests/allgather_test.cu
index e9f8331c..04eebafe 100644
--- a/tests/allgather_test.cu
+++ b/tests/allgather_test.cu
@@ -53,11 +53,11 @@ __device__ void allgather0(mscclppDevConn_t devConn, int rank, int world_size, i
 
   // this thread's role is a sender role
   // put your data asynchronously
-  devConn.put(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
+  devConn.putWithSignal(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
   // make sure everyone is put their data before some thread randomly blocks everyone else in signal
   __syncthreads();
   // push with flag and sync to make sure the data is received
-  devConn.signal();
+  devConn.flush();
 
   // this thread's role is a receiver role. wait on the semaphore to make sure the data is ready
   devConn.wait();
@@ -77,6 +77,7 @@ __device__ void allgather1(mscclppDevConn_t devConn, int rank, int world_size, i
       continue;
     // put your data to GPU (rank+i) % world_size and signal all in one call
     devConn.putWithSignal(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
+    devConn.flush();
   }
   // all connections wait for the signal from the sender
   devConn.wait();