From 518f325225ccece587aba8a0874a277d22cb4cc8 Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Wed, 3 May 2023 22:45:47 +0000 Subject: [PATCH] kernel 2 is also performant --- tests/allgather_test_cpp.cu | 14 +++++++++++--- tests/communicator_test_cpp.cu | 17 ----------------- 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/tests/allgather_test_cpp.cu b/tests/allgather_test_cpp.cu index aaff931c..ad473f8f 100644 --- a/tests/allgather_test_cpp.cu +++ b/tests/allgather_test_cpp.cu @@ -84,7 +84,7 @@ __device__ void localAllGather(mscclpp::channel::SimpleDeviceChannel devChan, in if ((remoteRank % nranksPerNode) == ((rank + i) % nranksPerNode)) { // put your data to GPU (rank+i) % nranksPerNode and signal in one call if ((threadIdx.x % 32) == 0) - devChan.putWithSignalAndFlush(offset, size); + devChan.putWithSignal(offset, size); } // wait for the data from GPU (rank-i) % nranksPerNode to arrive if ((remoteRank % nranksPerNode) == ((rank - i + nranksPerNode) % nranksPerNode)) { @@ -100,6 +100,9 @@ __device__ void allgather1(mscclpp::channel::SimpleDeviceChannel devChan, int ra { localAllGather(devChan, rank, world_size, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int)); + if (remoteRank / nranksPerNode == rank / nranksPerNode) + if ((threadIdx.x % 32) == 0) + devChan.flush(); } __device__ void allgather2(mscclpp::channel::SimpleDeviceChannel devChan, int rank, int world_size, int nranksPerNode, @@ -127,7 +130,7 @@ __device__ void allgather2(mscclpp::channel::SimpleDeviceChannel devChan, int ra if (remoteRank % nranksPerNode == rank % nranksPerNode) { // opposite side if ((threadIdx.x % 32) == 0) - devChan.putWithSignalAndFlush(rank * nelemsPerGPU * sizeof(int), + devChan.putWithSignal(rank * nelemsPerGPU * sizeof(int), (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize * sizeof(int)); if ((threadIdx.x % 32) == 0) devChan.wait(); @@ -147,7 +150,7 @@ __device__ void allgather2(mscclpp::channel::SimpleDeviceChannel devChan, int ra if (remoteRank % nranksPerNode == rank % nranksPerNode) { // opposite side if ((threadIdx.x % 32) == 0) - devChan.putWithSignalAndFlush((rank * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) * + devChan.putWithSignal((rank * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) * sizeof(int), nelemsPerGPU / pipelineSize * sizeof(int)); if ((threadIdx.x % 32) == 0) @@ -163,6 +166,11 @@ __device__ void allgather2(mscclpp::channel::SimpleDeviceChannel devChan, int ra (otherNghr * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) * sizeof(int), nelemsPerGPU / pipelineSize * sizeof(int)); } + + if (remoteRank / nranksPerNode == rank / nranksPerNode || remoteRank % nranksPerNode == rank % nranksPerNode) { + if ((threadIdx.x % 32) == 0) + devChan.flush(); + } } __global__ void kernel(int rank, int world_size, int nranksPerNode, size_t nelemsPerGPU, int kernel) diff --git a/tests/communicator_test_cpp.cu b/tests/communicator_test_cpp.cu index fcdd0f5a..56c8592e 100644 --- a/tests/communicator_test_cpp.cu +++ b/tests/communicator_test_cpp.cu @@ -39,23 +39,6 @@ void register_all_memories(mscclpp::Communicator& communicator, int rank, int wo remoteMemory[i] = futureRemoteMemory[i].get(); } } - - - // auto serialized = localMemory.serialize(); - // int serializedSize = serialized.size(); - // for (int i = 0; i < worldSize; i++) { - // if (i != rank){ - // communicator.bootstrapper()->send(serialized.data(), serializedSize, i, 0); - // } - // } - // for (int i = 0; i < worldSize; i++) { - // if (i != rank){ - // std::vector deserialized(serializedSize); - // communicator.bootstrapper()->recv(deserialized.data(), serializedSize, i, 0); - // auto remote = mscclpp::RegisteredMemory::deserialize(deserialized); - // remoteMemory[i] = remote; - // } - // } } void make_connections(mscclpp::Communicator& communicator, int rank, int worldSize, int nRanksPerNode, mscclpp::Transport myIbDevice, std::unordered_map>& connections){