From 17cbc84a14fcc6c1d0dbec09df331a0f979467b4 Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Sun, 19 Mar 2023 06:35:32 +0000 Subject: [PATCH] both allgather algorithms --- tests/allgather_test.cu | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/tests/allgather_test.cu b/tests/allgather_test.cu index b9035e28..d26e30f6 100644 --- a/tests/allgather_test.cu +++ b/tests/allgather_test.cu @@ -7,7 +7,8 @@ #include #include -#define RANKS_PER_NODE 2 +#define RANKS_PER_NODE 8 +#define KERNEL 1 #define MSCCLPPCHECK(call) do { \ mscclppResult_t res = call; \ @@ -61,7 +62,7 @@ __global__ void kernel(int rank, int world_size, int nelemsPerGPU) } // Each warp receives data from different ranks - +#if 0 // get a thread-local trigger and a request for waiting on it mscclppTrigger_t trig; mscclppRequest_t req = devConn.fifo.getTrigger(&trig); @@ -73,7 +74,24 @@ __global__ void kernel(int rank, int world_size, int nelemsPerGPU) devConn.fifo.waitTrigger(req); // Wait for receiving data from remote rank - while (*proxyFlag == baseFlag) {} + while (*proxyFlag == baseFlag); +#else + for (int i = 1; i < world_size; i++){ + __syncthreads(); + if (remoteRank != ((rank+i) % world_size)) continue; + // get a thread-local trigger and a request for waiting on it + mscclppTrigger_t trig; + mscclppRequest_t req = devConn.fifo.getTrigger(&trig); + + // Trigger sending data, flag and synchronize after + devConn.fifo.setTrigger(trig, mscclppFlag | mscclppData | mscclppSync, rank * nelemsPerGPU * sizeof(int), nelemsPerGPU*sizeof(int)); + + // Wait on the request to make sure it is safe to reuse buffer and flag + devConn.fifo.waitTrigger(req); + } + // Wait for receiving data from remote rank + while (*proxyFlag == baseFlag); +#endif } @@ -162,7 +180,7 @@ int main(int argc, const char *argv[]) int *data_d; uint64_t *flag_d; - size_t data_size = 1024*1; + size_t data_size = 1024*1024*1024; int nelemsPerGPU = data_size / sizeof(int) / world_size; CUDACHECK(cudaMalloc(&data_d, data_size)); CUDACHECK(cudaMalloc(&flag_d, sizeof(uint64_t)));