diff --git a/src/communicator.cc b/src/communicator.cc
index 79e45f8d..35936862 100644
--- a/src/communicator.cc
+++ b/src/communicator.cc
@@ -21,6 +21,7 @@ Communicator::Impl::Impl(std::shared_ptr<Bootstrap> bootstrap) : bootstrap_(
   INFO(MSCCLPP_INIT, "Host hash: %lx", hostHash);
   rankToHash_[bootstrap->getRank()] = hostHash;
   bootstrap->allGather(rankToHash_.data(), sizeof(uint64_t));
+  comm->rank = bootstrap->getRank();
 }
 
 Communicator::Impl::~Impl()
diff --git a/src/include/connection.hpp b/src/include/connection.hpp
index f957c8a1..42ca6d47 100644
--- a/src/include/connection.hpp
+++ b/src/include/connection.hpp
@@ -13,8 +13,8 @@ namespace mscclpp {
 
 class ConnectionBase : public Connection {
  public:
-  virtual void startSetup(std::shared_ptr<Bootstrap> bootstrap){};
-  virtual void endSetup(std::shared_ptr<Bootstrap> bootstrap){};
+  virtual void startSetup(std::shared_ptr<Bootstrap>){};
+  virtual void endSetup(std::shared_ptr<Bootstrap>){};
 };
 
 class CudaIpcConnection : public ConnectionBase
diff --git a/tests/communicator_test_cpp.cc b/tests/communicator_test_cpp.cc
index 7fccf57b..a0b12e43 100644
--- a/tests/communicator_test_cpp.cc
+++ b/tests/communicator_test_cpp.cc
@@ -35,14 +35,17 @@ void test_communicator(int rank, int worldSize, int nranksPerNode)
   if (bootstrap->getRank() == 0)
     std::cout << "Communicator initialization passed" << std::endl;
 
+  std::vector<std::shared_ptr<mscclpp::Connection>> connections;
   auto myIbDevice = findIb(rank % nranksPerNode);
   for (int i = 0; i < worldSize; i++) {
     if (i != rank) {
+      std::shared_ptr<mscclpp::Connection> conn;
       if (i / nranksPerNode == rank / nranksPerNode) {
-        auto connect = communicator->connect(i, 0, mscclpp::Transport::CudaIpc);
+        conn = communicator->connect(i, 0, mscclpp::Transport::CudaIpc);
       } else {
-        auto connect = communicator->connect(i, 0, myIbDevice);
+        conn = communicator->connect(i, 0, myIbDevice);
      }
+      connections.push_back(conn);
    }
  }
  communicator->connectionSetup();
@@ -63,20 +66,52 @@ void test_communicator(int rank, int worldSize, int nranksPerNode)
       bootstrap->send(serialized.data(), serializedSize, i, 1);
     }
   }
+  std::vector<mscclpp::RegisteredMemory> registeredMemories;
   for (int i = 0; i < worldSize; i++) {
     if (i != rank){
       int deserializedSize;
       bootstrap->recv(&deserializedSize, sizeof(int), i, 0);
       std::vector<char> deserialized(deserializedSize);
       bootstrap->recv(deserialized.data(), deserializedSize, i, 1);
-      // auto deserializedRegisteredMemory = mscclpp::RegisteredMemory::deserialize(deserialized);
+      auto deserializedRegisteredMemory = mscclpp::RegisteredMemory::deserialize(deserialized);
+      registeredMemories.push_back(std::move(deserializedRegisteredMemory));
     }
   }
+  if (bootstrap->getRank() == 0)
+    std::cout << "Memory registration passed" << std::endl;
+
+  assert(size % worldSize == 0);
+  size_t writeSize = size / worldSize;
+  size_t dataCount = size / sizeof(int);
+  // std::vector<int> hostBuffer(dataCount, 0);
+  std::shared_ptr<int[]> hostBuffer(new int[dataCount]);
+  for (int i = 0; i < dataCount; i++) {
+    hostBuffer[i] = rank;
+  }
+  CUDATHROW(cudaMemcpy(devicePtr, hostBuffer.get(), size, cudaMemcpyHostToDevice));
+
+  for (int i = 0; i < worldSize; i++) {
+    if (i != rank) {
+      int peerRankIndex = i < rank ? i : i - 1;
+      auto conn = connections[peerRankIndex];
+      conn->write(registeredMemories[peerRankIndex], rank * writeSize, registeredMemory, rank * writeSize, writeSize);
+    }
+  }
+  CUDATHROW(cudaDeviceSynchronize());
+  MPI_Barrier(MPI_COMM_WORLD);
+  CUDATHROW(cudaMemcpy(hostBuffer.get(), devicePtr, size, cudaMemcpyDeviceToHost));
+  size_t dataPerRank = writeSize / sizeof(int);
+  for (int i = 0; i < dataCount; i++) {
+    if (hostBuffer[i] != i / dataPerRank) {
+      throw std::runtime_error("Data mismatch, connection write failed");
+    }
+  }
   if (bootstrap->getRank() == 0)
-    std::cout << "Memory registeration passed" << std::endl;
+    std::cout << "Connection write passed" << std::endl;
+  CUDATHROW(cudaFree(devicePtr));
 
   if (bootstrap->getRank() == 0)
     std::cout << "--- MSCCLPP::Communicator tests passed! ---" << std::endl;
 }
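
A note for reviewers on the write loop: `connections` and `registeredMemories` are filled by loops that skip the local rank, so peer `i` sits at the compacted index `i < rank ? i : i - 1`. A minimal host-only sketch of that invariant (the helper name `compactPeerIndex` and the `worldSize` value are illustrative, not part of this patch):

```cpp
#include <cassert>
#include <vector>

// Compacted index of peer `i` in a per-peer array that skips `rank` itself.
static int compactPeerIndex(int i, int rank) { return i < rank ? i : i - 1; }

int main() {
  const int worldSize = 4;  // illustrative size
  for (int rank = 0; rank < worldSize; rank++) {
    std::vector<int> peers;  // built the same way the test builds `connections`
    for (int i = 0; i < worldSize; i++) {
      if (i != rank) peers.push_back(i);
    }
    // Invariant the write loop relies on: peers[compactPeerIndex(i, rank)] == i.
    for (int i = 0; i < worldSize; i++) {
      if (i != rank) assert(peers[compactPeerIndex(i, rank)] == i);
    }
  }
  return 0;
}
```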
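Similarly, the verification loop depends on the chunk layout the writes produce: each rank `r` writes its slice `[r * writeSize, (r + 1) * writeSize)` into every peer's buffer, and the local slice already holds `rank` from initialization, so element `i` should equal `i / dataPerRank` everywhere. A standalone sketch of that reasoning under assumed sizes (`worldSize = 4` and `dataPerRank = 8` are illustrative):

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

int main() {
  const int worldSize = 4;                           // illustrative size
  const size_t dataPerRank = 8;                      // ints per rank's chunk
  const size_t dataCount = dataPerRank * worldSize;  // ints in the whole buffer

  for (int rank = 0; rank < worldSize; rank++) {
    std::vector<int> buffer(dataCount, rank);  // local init, as in the test
    // Simulate the exchange: each peer r writes its own chunk
    // [r * dataPerRank, (r + 1) * dataPerRank) into this rank's buffer.
    for (int r = 0; r < worldSize; r++) {
      if (r == rank) continue;
      for (size_t j = r * dataPerRank; j < (r + 1) * dataPerRank; j++) buffer[j] = r;
    }
    // Same check as the test: chunk index == the rank that wrote it.
    for (size_t i = 0; i < dataCount; i++) {
      if (buffer[i] != static_cast<int>(i / dataPerRank))
        throw std::runtime_error("Data mismatch, connection write failed");
    }
  }
  return 0;
}
```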