From c24896b62f4d7e906bf2121303837cbae0bd3abd Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Thu, 27 Apr 2023 04:23:44 +0000 Subject: [PATCH] bootstrap to the communicator --- src/communicator.cc | 22 +--------------------- tests/bootstrap_test_cpp.cc | 4 ++-- 2 files changed, 3 insertions(+), 23 deletions(-) diff --git a/src/communicator.cc b/src/communicator.cc index 81753fb6..02ee7a87 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -54,15 +54,7 @@ static mscclppTransport_t transportToCStyle(TransportFlags flags) { } } -MSCCLPP_API_CPP Communicator::Communicator(int nranks, const char* ipPortPair, int rank) : pimpl(std::make_unique()) { - mscclppCommInitRank(&pimpl->comm, nranks, ipPortPair, rank); -} - -MSCCLPP_API_CPP Communicator::Communicator(int nranks, UniqueId id, int rank) : pimpl(std::make_unique()) { - static_assert(sizeof(mscclppUniqueId) == sizeof(UniqueId), "UniqueId size mismatch"); - mscclppUniqueId *cstyle_id = reinterpret_cast(&id); - mscclppCommInitRankFromId(&pimpl->comm, nranks, *cstyle_id, rank); -} +MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr bootstrap) : pimpl(std::make_unique(bootstrap)) {} MSCCLPP_API_CPP void Communicator::bootstrapAllGather(void* data, int size) { mscclppBootstrapAllGather(pimpl->comm, data, size); @@ -100,16 +92,4 @@ MSCCLPP_API_CPP void Communicator::connectionSetup() { } } -MSCCLPP_API_CPP int Communicator::rank() { - int result; - mscclppCommRank(pimpl->comm, &result); - return result; -} - -MSCCLPP_API_CPP int Communicator::size() { - int result; - mscclppCommSize(pimpl->comm, &result); - return result; -} - } // namespace mscclpp diff --git a/tests/bootstrap_test_cpp.cc b/tests/bootstrap_test_cpp.cc index 6c29e369..bdde8467 100644 --- a/tests/bootstrap_test_cpp.cc +++ b/tests/bootstrap_test_cpp.cc @@ -24,7 +24,7 @@ void test_barrier(std::shared_ptr bootstrap){ void test_sendrecv(std::shared_ptr bootstrap){ for (int i = 0; i < bootstrap->getNranks(); i++) { - if (bootstrap->getRank() == 0) + if (bootstrap->getRank() == i) continue; int msg1 = (bootstrap->getRank() + 1) * 3; int msg2 = (bootstrap->getRank() + 1) * 3 + 1; @@ -35,7 +35,7 @@ void test_sendrecv(std::shared_ptr bootstrap){ } for (int i = 0; i < bootstrap->getNranks(); i++) { - if (i == bootstrap->getRank()) + if (bootstrap->getRank() == i) continue; int msg1 = 0; int msg2 = 0;