From 6ea460bb3a85c3688467d84c3f3f1290907fa2a8 Mon Sep 17 00:00:00 2001 From: Madan Musuvathi Date: Wed, 22 Mar 2023 18:16:42 +0000 Subject: [PATCH] fusing signal with sync --- src/include/mscclpp.h | 17 +++++++---------- tests/allgather_test.cu | 15 +++++---------- 2 files changed, 12 insertions(+), 20 deletions(-) diff --git a/src/include/mscclpp.h b/src/include/mscclpp.h index 3132b164..3827f9fd 100644 --- a/src/include/mscclpp.h +++ b/src/include/mscclpp.h @@ -42,13 +42,12 @@ union alignas(16) mscclppTrigger { } fields; }; -typedef uint64_t mscclppRequest_t; typedef mscclppTrigger* mscclppTrigger_t; struct mscclppConcurrentFifo { #ifdef __CUDACC__ - __forceinline__ __device__ mscclppRequest_t push(uint64_t type, uint64_t dataOffset, uint64_t dataSize){ + __forceinline__ __device__ uint64_t push(uint64_t type, uint64_t dataOffset, uint64_t dataSize){ uint64_t curFifoHead = atomicAdd((unsigned long long int*)this->triggerFifoHead,1); while (curFifoHead >= MSCCLPP_PROXY_FIFO_SIZE + *((volatile uint64_t*)this->triggerFifoTail)); auto valptr = &(this->triggerFifo[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE].value); @@ -121,18 +120,16 @@ struct mscclppDevConn { __forceinline__ __device__ void put(uint64_t dataOffset, uint64_t dataSize){ fifo.push(mscclppData, dataOffset, dataSize); } - __forceinline__ __device__ mscclppRequest_t signal(){ + __forceinline__ __device__ void signal(){ epochIncrement(); - return fifo.push(mscclppFlag | mscclppSync, 1, 1); + uint64_t curFifoHead = fifo.push(mscclppFlag | mscclppSync, 1, 1); + while (*(volatile uint64_t *)fifo.triggerFifoTail <= curFifoHead); } - __forceinline__ __device__ mscclppRequest_t putWithSignal(uint64_t dataOffset, uint64_t dataSize){ + __forceinline__ __device__ void putWithSignal(uint64_t dataOffset, uint64_t dataSize){ epochIncrement(); - return fifo.push(mscclppData | mscclppFlag | mscclppSync, dataOffset, dataSize); - } - - __forceinline__ __device__ void sync(mscclppRequest_t req) { - while (*(volatile uint64_t *)fifo.triggerFifoTail <= req); + uint64_t curFifoHead = fifo.push(mscclppData | mscclppFlag | mscclppSync, dataOffset, dataSize); + while (*(volatile uint64_t *)fifo.triggerFifoTail <= curFifoHead); } __forceinline__ __device__ void wait(){ diff --git a/tests/allgather_test.cu b/tests/allgather_test.cu index b37d286e..83d17505 100644 --- a/tests/allgather_test.cu +++ b/tests/allgather_test.cu @@ -7,7 +7,7 @@ #include #include -#define RANKS_PER_NODE 8 +#define RANKS_PER_NODE 1 #define MSCCLPPCHECK(call) do { \ mscclppResult_t res = call; \ @@ -59,17 +59,14 @@ __global__ void kernel(int rank, int world_size, int nelemsPerGPU) // } // Each warp receives data from different ranks -#if 0 +#if 1 // push your data asynchronously devConn.put(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU*sizeof(int)); // push with flag and sync to make sure the data is received - auto req = devConn.signal(); - - devConn.sync(req); + devConn.signal(); devConn.wait(); - //while (*proxyFlag == baseFlag); #else for (int i = 1; i < world_size; i++){ @@ -79,11 +76,9 @@ __global__ void kernel(int rank, int world_size, int nelemsPerGPU) devConn.put(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU*sizeof(int)); // push with flag and sync to make sure the data is received - auto req = devConn.signal(); - - devConn.sync(req); - + devConn.signal(); } + devConn.wait(); // Wait for receiving data from remote rank // while (*proxyFlag == baseFlag);