mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-13 01:36:10 +00:00
connect test
This commit is contained in:
2
Makefile
2
Makefile
@@ -149,7 +149,7 @@ UTOBJTARGETS := $(UTOBJS:%=$(BUILDDIR)/$(OBJDIR)/%)
|
||||
UTBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(UTOBJS))
|
||||
|
||||
TESTSDIR := tests
|
||||
TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu bootstrap_test_cpp.cc) # allgather_test_cpp.cu
|
||||
TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu bootstrap_test_cpp.cc communicator_test_cpp.cc) # allgather_test_cpp.cu
|
||||
TESTSOBJS := $(patsubst %.cc,%.o,$(TESTSSRCS)) $(patsubst %.cu,%.o,$(TESTSSRCS))
|
||||
TESTSOBJTARGETS := $(TESTSOBJS:%=$(BUILDDIR)/$(OBJDIR)/%)
|
||||
TESTSBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(TESTSOBJS))
|
||||
|
||||
48
tests/communicator_test_cpp.cc
Normal file
48
tests/communicator_test_cpp.cc
Normal file
@@ -0,0 +1,48 @@
|
||||
#include "mscclpp.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <mpi.h>
|
||||
|
||||
void test_communicator(int rank, int worldSize, int nranksPerNode){
|
||||
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(rank, worldSize);
|
||||
mscclpp::UniqueId id;
|
||||
if (bootstrap->getRank() == 0)
|
||||
id = bootstrap->createUniqueId();
|
||||
MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
|
||||
bootstrap->initialize(id);
|
||||
|
||||
auto communicator = std::make_shared<mscclpp::Communicator>(bootstrap);
|
||||
for (int i = 0; i < worldSize; i++){
|
||||
if (i != rank){
|
||||
if (i % nranksPerNode == rank % nranksPerNode)
|
||||
auto connect = communicator->connect(i, 0, mscclpp::TransportCudaIpc);
|
||||
else
|
||||
auto connect = communicator->connect(i, 0, mscclpp::TransportAllIB);
|
||||
}
|
||||
}
|
||||
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "--- MSCCLPP::Communicator tests passed! ---" << std::endl;
|
||||
}
|
||||
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
int rank, worldSize;
|
||||
MPI_Init(&argc, &argv);
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &worldSize);
|
||||
MPI_Comm shmcomm;
|
||||
MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &shmcomm);
|
||||
int shmWorldSize;
|
||||
MPI_Comm_size(shmcomm, &shmWorldSize);
|
||||
int nranksPerNode = shmWorldSize;
|
||||
MPI_Comm_free(&shmcomm);
|
||||
|
||||
test_communicator(rank, worldSize, nranksPerNode);
|
||||
|
||||
MPI_Finalize();
|
||||
return 0;
|
||||
}
|
||||
Reference in New Issue
Block a user