bootstrap to the communicator

This commit is contained in:
Saeed Maleki
2023-04-27 04:23:44 +00:00
parent 7913d90158
commit c24896b62f
2 changed files with 3 additions and 23 deletions

View File

@@ -54,15 +54,7 @@ static mscclppTransport_t transportToCStyle(TransportFlags flags) {
}
}
MSCCLPP_API_CPP Communicator::Communicator(int nranks, const char* ipPortPair, int rank) : pimpl(std::make_unique<Impl>()) {
mscclppCommInitRank(&pimpl->comm, nranks, ipPortPair, rank);
}
MSCCLPP_API_CPP Communicator::Communicator(int nranks, UniqueId id, int rank) : pimpl(std::make_unique<Impl>()) {
static_assert(sizeof(mscclppUniqueId) == sizeof(UniqueId), "UniqueId size mismatch");
mscclppUniqueId *cstyle_id = reinterpret_cast<mscclppUniqueId*>(&id);
mscclppCommInitRankFromId(&pimpl->comm, nranks, *cstyle_id, rank);
}
MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr<BaseBootstrap> bootstrap) : pimpl(std::make_unique<Impl>(bootstrap)) {}
MSCCLPP_API_CPP void Communicator::bootstrapAllGather(void* data, int size) {
mscclppBootstrapAllGather(pimpl->comm, data, size);
@@ -100,16 +92,4 @@ MSCCLPP_API_CPP void Communicator::connectionSetup() {
}
}
MSCCLPP_API_CPP int Communicator::rank() {
int result;
mscclppCommRank(pimpl->comm, &result);
return result;
}
MSCCLPP_API_CPP int Communicator::size() {
int result;
mscclppCommSize(pimpl->comm, &result);
return result;
}
} // namespace mscclpp

View File

@@ -24,7 +24,7 @@ void test_barrier(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap){
void test_sendrecv(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap){
for (int i = 0; i < bootstrap->getNranks(); i++) {
if (bootstrap->getRank() == 0)
if (bootstrap->getRank() == i)
continue;
int msg1 = (bootstrap->getRank() + 1) * 3;
int msg2 = (bootstrap->getRank() + 1) * 3 + 1;
@@ -35,7 +35,7 @@ void test_sendrecv(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap){
}
for (int i = 0; i < bootstrap->getNranks(); i++) {
if (i == bootstrap->getRank())
if (bootstrap->getRank() == i)
continue;
int msg1 = 0;
int msg2 = 0;