diff --git a/src/communicator.cc b/src/communicator.cc index 81753fb6..02ee7a87 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -54,15 +54,7 @@ static mscclppTransport_t transportToCStyle(TransportFlags flags) { } } -MSCCLPP_API_CPP Communicator::Communicator(int nranks, const char* ipPortPair, int rank) : pimpl(std::make_unique()) { - mscclppCommInitRank(&pimpl->comm, nranks, ipPortPair, rank); -} - -MSCCLPP_API_CPP Communicator::Communicator(int nranks, UniqueId id, int rank) : pimpl(std::make_unique()) { - static_assert(sizeof(mscclppUniqueId) == sizeof(UniqueId), "UniqueId size mismatch"); - mscclppUniqueId *cstyle_id = reinterpret_cast(&id); - mscclppCommInitRankFromId(&pimpl->comm, nranks, *cstyle_id, rank); -} +MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr bootstrap) : pimpl(std::make_unique(bootstrap)) {} MSCCLPP_API_CPP void Communicator::bootstrapAllGather(void* data, int size) { mscclppBootstrapAllGather(pimpl->comm, data, size); @@ -100,16 +92,4 @@ MSCCLPP_API_CPP void Communicator::connectionSetup() { } } -MSCCLPP_API_CPP int Communicator::rank() { - int result; - mscclppCommRank(pimpl->comm, &result); - return result; -} - -MSCCLPP_API_CPP int Communicator::size() { - int result; - mscclppCommSize(pimpl->comm, &result); - return result; -} - } // namespace mscclpp diff --git a/tests/bootstrap_test_cpp.cc b/tests/bootstrap_test_cpp.cc index 6c29e369..bdde8467 100644 --- a/tests/bootstrap_test_cpp.cc +++ b/tests/bootstrap_test_cpp.cc @@ -24,7 +24,7 @@ void test_barrier(std::shared_ptr bootstrap){ void test_sendrecv(std::shared_ptr bootstrap){ for (int i = 0; i < bootstrap->getNranks(); i++) { - if (bootstrap->getRank() == 0) + if (bootstrap->getRank() == i) continue; int msg1 = (bootstrap->getRank() + 1) * 3; int msg2 = (bootstrap->getRank() + 1) * 3 + 1; @@ -35,7 +35,7 @@ void test_sendrecv(std::shared_ptr bootstrap){ } for (int i = 0; i < bootstrap->getNranks(); i++) { - if (i == bootstrap->getRank()) + if (bootstrap->getRank() == i) continue; int msg1 = 0; int msg2 = 0;