From 8fc822c8489fe6653cfbef1a2b9b5cf633522fc0 Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Tue, 25 Apr 2023 22:26:48 +0000 Subject: [PATCH] more tests for bootstrap --- src/bootstrap/bootstrap.cc | 22 ++++++ src/include/mscclpp.hpp | 4 ++ tests/bootstrap_test_cpp.cc | 140 +++++++++++++++++++++++++++--------- 3 files changed, 133 insertions(+), 33 deletions(-) diff --git a/src/bootstrap/bootstrap.cc b/src/bootstrap/bootstrap.cc index 51ac66d9..dfce50b4 100644 --- a/src/bootstrap/bootstrap.cc +++ b/src/bootstrap/bootstrap.cc @@ -79,6 +79,8 @@ public: void establishConnections(); UniqueId createUniqueId(); UniqueId getUniqueId() const; + int getRank(); + int getNranks(); void allGather(void* allData, int size); void send(void* data, int size, int peer, int tag); void recv(void* data, int size, int peer, int tag); @@ -137,6 +139,16 @@ UniqueId Bootstrap::Impl::createUniqueId() return getUniqueId(); } +int Bootstrap::Impl::getRank() +{ + return rank_; +} + +int Bootstrap::Impl::getNranks() +{ + return nRanks_; +} + void Bootstrap::Impl::initialize(const UniqueId uniqueId) { netInit(""); @@ -455,6 +467,16 @@ MSCCLPP_API_CPP UniqueId Bootstrap::getUniqueId() const return pimpl_->getUniqueId(); } +MSCCLPP_API_CPP int Bootstrap::getRank() +{ + return pimpl_->getRank(); +} + +MSCCLPP_API_CPP int Bootstrap::getNranks() +{ + return pimpl_->getNranks(); +} + MSCCLPP_API_CPP void Bootstrap::send(void* data, int size, int peer, int tag) { pimpl_->send(data, size, peer, tag); diff --git a/src/include/mscclpp.hpp b/src/include/mscclpp.hpp index 12ac7873..bcbbf41d 100644 --- a/src/include/mscclpp.hpp +++ b/src/include/mscclpp.hpp @@ -30,6 +30,8 @@ class BaseBootstrap public: BaseBootstrap(){}; virtual ~BaseBootstrap() = default; + virtual int getRank() = 0; + virtual int getNranks() = 0; virtual void send(void* data, int size, int peer, int tag) = 0; virtual void recv(void* data, int size, int peer, int tag) = 0; virtual void allGather(void* allData, int size) = 0; @@ -47,6 +49,8 @@ public: void initialize(UniqueId uniqueId); void initialize(std::string ipPortPair); + int getRank() override; + int getNranks() override; void send(void* data, int size, int peer, int tag) override; void recv(void* data, int size, int peer, int tag) override; void allGather(void* allData, int size) override; diff --git a/tests/bootstrap_test_cpp.cc b/tests/bootstrap_test_cpp.cc index e7160edd..34e58b59 100644 --- a/tests/bootstrap_test_cpp.cc +++ b/tests/bootstrap_test_cpp.cc @@ -1,49 +1,41 @@ #include "mscclpp.hpp" #include - +#include +#include #include -int main() -{ - int rank, worldSize; - MPI_Init(NULL, NULL); - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - MPI_Comm_size(MPI_COMM_WORLD, &worldSize); - - std::shared_ptr bootstrap(new mscclpp::Bootstrap(rank, worldSize)); - // bootstrap->Initialize("costsim-dev-00000A:50000"); - mscclpp::UniqueId id; - if (rank == 0) - id = bootstrap->createUniqueId(); - MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD); - bootstrap->initialize(id); - - std::vector tmp(worldSize, 0); - tmp[rank] = rank + 1; +void test_allgather(std::shared_ptr bootstrap){ + std::vector tmp(bootstrap->getNranks(), 0); + tmp[bootstrap->getRank()] = bootstrap->getRank() + 1; bootstrap->allGather(tmp.data(), sizeof(int)); - for (int i = 0; i < worldSize; i++) { - if (tmp[i] != i + 1) - printf("error AllGather: rank %d: tmp[%d] = %d\n", rank, i, tmp[i]); + for (int i = 0; i < bootstrap->getNranks(); i++) { + assert(tmp[i] == i + 1); } - printf("rank %d: AllGather test passed!\n", rank); + if (bootstrap->getRank() == 0) + std::cout << "AllGather test passed!" << std::endl; +} +void test_barrier(std::shared_ptr bootstrap){ bootstrap->barrier(); - printf("rank %d: Barrier test passed!\n", rank); + if (bootstrap->getRank() == 0) + std::cout << "Barrier test passed!" << std::endl; +} - for (int i = 0; i < worldSize; i++) { - if (i == rank) +void test_sendrecv(std::shared_ptr bootstrap){ + for (int i = 0; i < bootstrap->getNranks(); i++) { + if (bootstrap->getRank() == 0) continue; - int msg1 = (rank + 1) * 3; - int msg2 = (rank + 1) * 3 + 1; - int msg3 = (rank + 1) * 3 + 2; + int msg1 = (bootstrap->getRank() + 1) * 3; + int msg2 = (bootstrap->getRank() + 1) * 3 + 1; + int msg3 = (bootstrap->getRank() + 1) * 3 + 2; bootstrap->send(&msg1, sizeof(int), i, 0); bootstrap->send(&msg2, sizeof(int), i, 1); bootstrap->send(&msg3, sizeof(int), i, 2); } - for (int i = 0; i < worldSize; i++) { - if (i == rank) + for (int i = 0; i < bootstrap->getNranks(); i++) { + if (i == bootstrap->getRank()) continue; int msg1 = 0; int msg2 = 0; @@ -52,10 +44,92 @@ int main() bootstrap->recv(&msg2, sizeof(int), i, 1); bootstrap->recv(&msg3, sizeof(int), i, 2); bootstrap->recv(&msg1, sizeof(int), i, 0); - if (msg1 != (i + 1) * 3 || msg2 != (i + 1) * 3 + 1 || msg3 != (i + 1) * 3 + 2) - printf("error Send/Recv: rank %d: msg1 = %d, msg2 = %d\n", rank, msg1, msg2); + assert(msg1 == (i + 1) * 3); + assert(msg2 == (i + 1) * 3 + 1); + assert(msg3 == (i + 1) * 3 + 2); } - printf("rank %d: Send/Recv test passed!\n", rank); + if (bootstrap->getRank() == 0) + std::cout << "Send/Recv test passed!" << std::endl; +} + +void test_all(std::shared_ptr bootstrap){ + test_allgather(bootstrap); + test_barrier(bootstrap); + // test_sendrecv(bootstrap); +} + +void test_mscclpp_bootstrap_with_id(int rank, int worldSize){ + std::shared_ptr bootstrap(new mscclpp::Bootstrap(rank, worldSize)); + mscclpp::UniqueId id; + if (bootstrap->getRank() == 0) + id = bootstrap->createUniqueId(); + MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD); + bootstrap->initialize(id); + + test_all(bootstrap); + if (bootstrap->getRank() == 0) + std::cout << "--- MSCCLPP::Bootstrap test with unique id passed! ---" << std::endl; +} + +void test_mscclpp_bootstrap_with_ip_port_pair(int rank, int worldSize, char* ipPortPiar){ + std::shared_ptr bootstrap(new mscclpp::Bootstrap(rank, worldSize)); + bootstrap->initialize(ipPortPiar); + + test_all(bootstrap); + if (bootstrap->getRank() == 0) + std::cout << "--- MSCCLPP::Bootstrap test with ip_port pair passed! ---" << std::endl; +} + +class MPIBootstrap : public mscclpp::BaseBootstrap { +public: + MPIBootstrap() : BaseBootstrap() {} + int getRank() override { + int rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + return rank; + } + int getNranks() override { + int worldSize; + MPI_Comm_size(MPI_COMM_WORLD, &worldSize); + return worldSize; + } + void allGather(void *sendbuf, int size) override { + MPI_Allgather(MPI_IN_PLACE, 0, MPI_BYTE, sendbuf, size, MPI_BYTE, MPI_COMM_WORLD); + } + void barrier() override { + MPI_Barrier(MPI_COMM_WORLD); + } + void send(void *sendbuf, int size, int dest, int tag) override { + MPI_Send(sendbuf, size, MPI_BYTE, dest, tag, MPI_COMM_WORLD); + } + void recv(void *recvbuf, int size, int source, int tag) override { + MPI_Recv(recvbuf, size, MPI_BYTE, source, tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE); + } +}; + +void test_mpi_bootstrap(){ + std::shared_ptr bootstrap(new MPIBootstrap()); + test_all(bootstrap); + if (bootstrap->getRank() == 0) + std::cout << "--- MPI Bootstrap test passed! ---" << std::endl; +} + +int main(int argc, char **argv) +{ + int rank, worldSize; + MPI_Init(&argc, &argv); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &worldSize); + if (argc > 2){ + if (rank == 0) + std::cout << "Usage: " << argv[0] << " [ip:port]" << std::endl; + MPI_Finalize(); + return 0; + } + test_mscclpp_bootstrap_with_id(rank, worldSize); + if (argc == 2) + test_mscclpp_bootstrap_with_ip_port_pair(rank, worldSize, argv[1]); + test_mpi_bootstrap(); MPI_Finalize(); return 0;