#include "mscclpp.hpp"

// NOTE(review): the header names and the std::make_shared template arguments
// were missing from the reviewed text (angle-bracket contents stripped). The
// ones below are what the code visibly needs; confirm against the original
// file and mscclpp.hpp before merging.
#include <mpi.h>

#include <cstdio>
#include <exception>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>

/// @brief Pick the InfiniBand transport flag that matches a node-local rank.
///
/// Local rank N maps to mscclpp::TransportIBN for N in [0, 7].
///
/// @param localRank Rank of the calling process within its node.
/// @return The transport flag at index @p localRank.
/// @throws std::out_of_range if @p localRank is negative or has no matching
///         entry in the table (previously an out-of-bounds array read).
mscclpp::TransportFlags findIb(int localRank)
{
  const mscclpp::TransportFlags ibTransports[] = {
      mscclpp::TransportIB0, mscclpp::TransportIB1, mscclpp::TransportIB2, mscclpp::TransportIB3,
      mscclpp::TransportIB4, mscclpp::TransportIB5, mscclpp::TransportIB6, mscclpp::TransportIB7};
  const int numIbTransports = static_cast<int>(sizeof(ibTransports) / sizeof(ibTransports[0]));

  if (localRank < 0 || localRank >= numIbTransports) {
    throw std::out_of_range("findIb: no IB transport for local rank " + std::to_string(localRank));
  }
  return ibTransports[localRank];
}

/// @brief Bootstrap a communicator and open a connection to every other rank.
///
/// Rank 0 creates the unique id, which is broadcast to all ranks over
/// MPI_COMM_WORLD before the bootstrap is initialized. Peers whose index falls
/// in the same block of @p nranksPerNode ranks are connected with
/// mscclpp::TransportCudaIpc; all other peers are connected with the transport
/// returned by findIb() for this rank's node-local index.
///
/// @param rank          Rank of the calling process in MPI_COMM_WORLD.
/// @param worldSize     Total number of ranks.
/// @param nranksPerNode Number of ranks per node; must be positive because it
///                      is used as a divisor.
void test_communicator(int rank, int worldSize, int nranksPerNode)
{
  auto bootstrap = std::make_shared<mscclpp::Bootstrap>(rank, worldSize);

  mscclpp::UniqueId id;
  if (bootstrap->getRank() == 0) {
    id = bootstrap->createUniqueId();
  }
  // Every rank needs the id created on rank 0 before it can initialize.
  MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
  bootstrap->initialize(id);

  auto communicator = std::make_shared<mscclpp::Communicator>(bootstrap);

  const int myNode = rank / nranksPerNode;
  for (int peer = 0; peer < worldSize; ++peer) {
    if (peer == rank) {
      continue;  // no connection to self
    }
    if (peer / nranksPerNode == myNode) {
      printf("i %d rank %d nranksPerNode %d\n", peer, rank, nranksPerNode);
      communicator->connect(peer, 0, mscclpp::TransportCudaIpc);
    } else {
      communicator->connect(peer, 0, findIb(rank % nranksPerNode));
    }
  }

  communicator->connectionSetup();

  if (bootstrap->getRank() == 0) {
    std::cout << "--- MSCCLPP::Communicator tests passed! ---" << std::endl;
  }
}

/// @brief Entry point: initialize MPI, derive the ranks-per-node count from
///        the size of the shared-memory communicator, and run the test.
int main(int argc, char** argv)
{
  int rank, worldSize;
  MPI_Init(&argc, &argv);
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &worldSize);

  // The size of the MPI_COMM_TYPE_SHARED split is used as the number of ranks
  // per node.
  MPI_Comm shmcomm;
  MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &shmcomm);
  int shmWorldSize;
  MPI_Comm_size(shmcomm, &shmWorldSize);
  int nranksPerNode = shmWorldSize;
  MPI_Comm_free(&shmcomm);

  try {
    test_communicator(rank, worldSize, nranksPerNode);
  } catch (const std::exception& e) {
    // Abort the whole job so the remaining ranks are not left waiting on a
    // peer that has already failed.
    fprintf(stderr, "rank %d: test_communicator failed: %s\n", rank, e.what());
    MPI_Abort(MPI_COMM_WORLD, 1);
  }

  MPI_Finalize();
  return 0;
}