#include "mscclpp.hpp"

#include <cuda_runtime.h>
#include <mpi.h>

#include <cassert>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

// Throw std::runtime_error carrying the CUDA error string when a runtime API call fails.
#define CUDATHROW(cmd)                                                                         \
  do {                                                                                         \
    cudaError_t err = cmd;                                                                     \
    if (err != cudaSuccess) {                                                                  \
      throw std::runtime_error(std::string("Cuda failure '") + cudaGetErrorString(err) + "'"); \
    }                                                                                          \
  } while (false)

// Map a node-local rank to that rank's InfiniBand transport (one HCA per local rank).
// Precondition: 0 <= localRank < 8 — asserted so a bad rank cannot read past the table.
mscclpp::Transport findIb(int localRank) {
  mscclpp::Transport IBs[] = {mscclpp::Transport::IB0, mscclpp::Transport::IB1, mscclpp::Transport::IB2,
                              mscclpp::Transport::IB3, mscclpp::Transport::IB4, mscclpp::Transport::IB5,
                              mscclpp::Transport::IB6, mscclpp::Transport::IB7};
  assert(localRank >= 0 && static_cast<size_t>(localRank) < sizeof(IBs) / sizeof(IBs[0]));
  return IBs[localRank];
}

// End-to-end smoke test of mscclpp::Communicator:
//   1. bootstrap + communicator construction (unique id shared over MPI),
//   2. connection setup (CudaIpc intra-node, IB inter-node),
//   3. device-memory registration + serialized exchange over the bootstrap,
//   4. one-sided write of each rank's slice into every peer's buffer, then verification.
// Throws std::runtime_error on any CUDA failure or data mismatch.
void test_communicator(int rank, int worldSize, int nranksPerNode) {
  auto bootstrap = std::make_shared<mscclpp::Bootstrap>(rank, worldSize);
  mscclpp::UniqueId id;
  // Rank 0 creates the unique id; everyone else learns it via MPI broadcast.
  if (bootstrap->getRank() == 0) id = bootstrap->createUniqueId();
  MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
  bootstrap->initialize(id);
  auto communicator = std::make_shared<mscclpp::Communicator>(bootstrap);
  if (bootstrap->getRank() == 0) std::cout << "Communicator initialization passed" << std::endl;

  // Connect to every peer: CudaIpc when the peer lives on the same node, IB otherwise.
  // Ranks are assumed to be laid out node-contiguously (rank / nranksPerNode == node id).
  std::vector<std::shared_ptr<mscclpp::Connection>> connections;
  auto myIbDevice = findIb(rank % nranksPerNode);
  for (int i = 0; i < worldSize; i++) {
    if (i == rank) continue;
    const bool sameNode = (i / nranksPerNode == rank / nranksPerNode);
    connections.push_back(communicator->connect(i, 0, sameNode ? mscclpp::Transport::CudaIpc : myIbDevice));
  }
  communicator->connectionSetup();
  if (bootstrap->getRank() == 0) std::cout << "Connection setup passed" << std::endl;

  // Register one device buffer usable by both transports, then exchange the
  // serialized registration with every peer over the bootstrap (size first, bytes second).
  int* devicePtr;
  int size = 1024;
  CUDATHROW(cudaMalloc(&devicePtr, size));
  auto registeredMemory = communicator->registerMemory(devicePtr, size, mscclpp::Transport::CudaIpc | myIbDevice);
  for (int i = 0; i < worldSize; i++) {
    if (i != rank) {
      auto serialized = registeredMemory.serialize();
      int serializedSize = static_cast<int>(serialized.size());
      bootstrap->send(&serializedSize, sizeof(int), i, 0);
      bootstrap->send(serialized.data(), serializedSize, i, 1);
    }
  }
  std::vector<mscclpp::RegisteredMemory> registeredMemories;
  for (int i = 0; i < worldSize; i++) {
    if (i != rank) {
      int deserializedSize;
      bootstrap->recv(&deserializedSize, sizeof(int), i, 0);
      std::vector<char> deserialized(deserializedSize);
      bootstrap->recv(deserialized.data(), deserializedSize, i, 1);
      registeredMemories.push_back(mscclpp::RegisteredMemory::deserialize(deserialized));
    }
  }
  if (bootstrap->getRank() == 0) std::cout << "Memory registration passed" << std::endl;

  // Fill the whole local buffer with our rank, then write our slice
  // ([rank*writeSize, (rank+1)*writeSize)) into every peer's registered buffer.
  assert(size % worldSize == 0);
  size_t writeSize = size / worldSize;
  size_t dataCount = size / sizeof(int);
  std::shared_ptr<int[]> hostBuffer(new int[dataCount]);  // int[] form so shared_ptr uses delete[]
  for (size_t i = 0; i < dataCount; i++) {
    hostBuffer[i] = rank;
  }
  CUDATHROW(cudaMemcpy(devicePtr, hostBuffer.get(), size, cudaMemcpyHostToDevice));
  for (int i = 0; i < worldSize; i++) {
    if (i != rank) {
      // connections/registeredMemories were built skipping our own rank,
      // so peers above us are shifted down by one.
      int peerRankIndex = i < rank ? i : i - 1;
      connections[peerRankIndex]->write(registeredMemories[peerRankIndex], rank * writeSize, registeredMemory,
                                        rank * writeSize, writeSize);
    }
  }
  CUDATHROW(cudaDeviceSynchronize());
  // NOTE(review): the barrier only orders verification after every peer has *issued*
  // its write; if Connection::write is not guaranteed remotely visible by then
  // (e.g. IB may need an explicit flush), a per-connection flush is required here —
  // confirm against the mscclpp Connection API.
  MPI_Barrier(MPI_COMM_WORLD);
  CUDATHROW(cudaMemcpy(hostBuffer.get(), devicePtr, size, cudaMemcpyDeviceToHost));

  // After the slice exchange, element i must hold the rank owning slice i / dataPerRank
  // (our own slice keeps the value we filled in locally).
  size_t dataPerRank = writeSize / sizeof(int);
  for (size_t i = 0; i < dataCount; i++) {
    if (hostBuffer[i] != static_cast<int>(i / dataPerRank)) {
      throw std::runtime_error("Data mismatch, connection write failed");
    }
  }
  if (bootstrap->getRank() == 0) std::cout << "Connection write passed" << std::endl;

  CUDATHROW(cudaFree(devicePtr));
  if (bootstrap->getRank() == 0) std::cout << "--- MSCCLPP::Communicator tests passed! ---" << std::endl;
}

// Initialize MPI, derive ranks-per-node from a shared-memory communicator split,
// run the communicator smoke test, and finalize MPI.
int main(int argc, char** argv) {
  int rank, worldSize;
  MPI_Init(&argc, &argv);
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &worldSize);

  // Ranks sharing a node land in the same MPI_COMM_TYPE_SHARED communicator,
  // so that communicator's size is the number of ranks per node.
  MPI_Comm shmcomm;
  MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &shmcomm);
  int shmWorldSize;
  MPI_Comm_size(shmcomm, &shmWorldSize);
  int nranksPerNode = shmWorldSize;
  MPI_Comm_free(&shmcomm);

  test_communicator(rank, worldSize, nranksPerNode);

  MPI_Finalize();
  return 0;
}