diff --git a/Makefile b/Makefile index b2d2cceb..950751d7 100644 --- a/Makefile +++ b/Makefile @@ -149,7 +149,7 @@ UTOBJTARGETS := $(UTOBJS:%=$(BUILDDIR)/$(OBJDIR)/%) UTBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(UTOBJS)) TESTSDIR := tests -TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu bootstrap_test_cpp.cc) # allgather_test_cpp.cu +TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu bootstrap_test_cpp.cc communicator_test_cpp.cc) # allgather_test_cpp.cu TESTSOBJS := $(patsubst %.cc,%.o,$(TESTSSRCS)) $(patsubst %.cu,%.o,$(TESTSSRCS)) TESTSOBJTARGETS := $(TESTSOBJS:%=$(BUILDDIR)/$(OBJDIR)/%) TESTSBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(TESTSOBJS)) diff --git a/tests/communicator_test_cpp.cc b/tests/communicator_test_cpp.cc new file mode 100644 index 00000000..fc3a72e8 --- /dev/null +++ b/tests/communicator_test_cpp.cc @@ -0,0 +1,48 @@ +#include "mscclpp.hpp" + +#include +#include +#include +#include + +void test_communicator(int rank, int worldSize, int nranksPerNode){ + auto bootstrap = std::make_shared(rank, worldSize); + mscclpp::UniqueId id; + if (bootstrap->getRank() == 0) + id = bootstrap->createUniqueId(); + MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD); + bootstrap->initialize(id); + + auto communicator = std::make_shared(bootstrap); + for (int i = 0; i < worldSize; i++){ + if (i != rank){ + if (i % nranksPerNode == rank % nranksPerNode) + auto connect = communicator->connect(i, 0, mscclpp::TransportCudaIpc); + else + auto connect = communicator->connect(i, 0, mscclpp::TransportAllIB); + } + } + + if (bootstrap->getRank() == 0) + std::cout << "--- MSCCLPP::Communicator tests passed! ---" << std::endl; +} + + +int main(int argc, char **argv) +{ + int rank, worldSize; + MPI_Init(&argc, &argv); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &worldSize); + MPI_Comm shmcomm; + MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &shmcomm); + int shmWorldSize; + MPI_Comm_size(shmcomm, &shmWorldSize); + int nranksPerNode = shmWorldSize; + MPI_Comm_free(&shmcomm); + + test_communicator(rank, worldSize, nranksPerNode); + + MPI_Finalize(); + return 0; +} \ No newline at end of file