mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-26 08:01:00 +00:00
bootstrap to the communicator
This commit is contained in:
@@ -54,15 +54,7 @@ static mscclppTransport_t transportToCStyle(TransportFlags flags) {
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Communicator::Communicator(int nranks, const char* ipPortPair, int rank) : pimpl(std::make_unique<Impl>()) {
|
||||
mscclppCommInitRank(&pimpl->comm, nranks, ipPortPair, rank);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Communicator::Communicator(int nranks, UniqueId id, int rank) : pimpl(std::make_unique<Impl>()) {
|
||||
static_assert(sizeof(mscclppUniqueId) == sizeof(UniqueId), "UniqueId size mismatch");
|
||||
mscclppUniqueId *cstyle_id = reinterpret_cast<mscclppUniqueId*>(&id);
|
||||
mscclppCommInitRankFromId(&pimpl->comm, nranks, *cstyle_id, rank);
|
||||
}
|
||||
MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr<BaseBootstrap> bootstrap) : pimpl(std::make_unique<Impl>(bootstrap)) {}
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::bootstrapAllGather(void* data, int size) {
|
||||
mscclppBootstrapAllGather(pimpl->comm, data, size);
|
||||
@@ -100,16 +92,4 @@ MSCCLPP_API_CPP void Communicator::connectionSetup() {
|
||||
}
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP int Communicator::rank() {
|
||||
int result;
|
||||
mscclppCommRank(pimpl->comm, &result);
|
||||
return result;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP int Communicator::size() {
|
||||
int result;
|
||||
mscclppCommSize(pimpl->comm, &result);
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -24,7 +24,7 @@ void test_barrier(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap){
|
||||
|
||||
void test_sendrecv(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap){
|
||||
for (int i = 0; i < bootstrap->getNranks(); i++) {
|
||||
if (bootstrap->getRank() == 0)
|
||||
if (bootstrap->getRank() == i)
|
||||
continue;
|
||||
int msg1 = (bootstrap->getRank() + 1) * 3;
|
||||
int msg2 = (bootstrap->getRank() + 1) * 3 + 1;
|
||||
@@ -35,7 +35,7 @@ void test_sendrecv(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap){
|
||||
}
|
||||
|
||||
for (int i = 0; i < bootstrap->getNranks(); i++) {
|
||||
if (i == bootstrap->getRank())
|
||||
if (bootstrap->getRank() == i)
|
||||
continue;
|
||||
int msg1 = 0;
|
||||
int msg2 = 0;
|
||||
|
||||
Reference in New Issue
Block a user