mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 09:17:06 +00:00
more tests for bootstrap
This commit is contained in:
@@ -79,6 +79,8 @@ public:
|
||||
void establishConnections();
|
||||
UniqueId createUniqueId();
|
||||
UniqueId getUniqueId() const;
|
||||
int getRank();
|
||||
int getNranks();
|
||||
void allGather(void* allData, int size);
|
||||
void send(void* data, int size, int peer, int tag);
|
||||
void recv(void* data, int size, int peer, int tag);
|
||||
@@ -137,6 +139,16 @@ UniqueId Bootstrap::Impl::createUniqueId()
|
||||
return getUniqueId();
|
||||
}
|
||||
|
||||
int Bootstrap::Impl::getRank()
|
||||
{
|
||||
return rank_;
|
||||
}
|
||||
|
||||
int Bootstrap::Impl::getNranks()
|
||||
{
|
||||
return nRanks_;
|
||||
}
|
||||
|
||||
void Bootstrap::Impl::initialize(const UniqueId uniqueId)
|
||||
{
|
||||
netInit("");
|
||||
@@ -455,6 +467,16 @@ MSCCLPP_API_CPP UniqueId Bootstrap::getUniqueId() const
|
||||
return pimpl_->getUniqueId();
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP int Bootstrap::getRank()
|
||||
{
|
||||
return pimpl_->getRank();
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP int Bootstrap::getNranks()
|
||||
{
|
||||
return pimpl_->getNranks();
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Bootstrap::send(void* data, int size, int peer, int tag)
|
||||
{
|
||||
pimpl_->send(data, size, peer, tag);
|
||||
|
||||
@@ -30,6 +30,8 @@ class BaseBootstrap
|
||||
public:
|
||||
BaseBootstrap(){};
|
||||
virtual ~BaseBootstrap() = default;
|
||||
virtual int getRank() = 0;
|
||||
virtual int getNranks() = 0;
|
||||
virtual void send(void* data, int size, int peer, int tag) = 0;
|
||||
virtual void recv(void* data, int size, int peer, int tag) = 0;
|
||||
virtual void allGather(void* allData, int size) = 0;
|
||||
@@ -47,6 +49,8 @@ public:
|
||||
|
||||
void initialize(UniqueId uniqueId);
|
||||
void initialize(std::string ipPortPair);
|
||||
int getRank() override;
|
||||
int getNranks() override;
|
||||
void send(void* data, int size, int peer, int tag) override;
|
||||
void recv(void* data, int size, int peer, int tag) override;
|
||||
void allGather(void* allData, int size) override;
|
||||
|
||||
@@ -1,49 +1,41 @@
|
||||
#include "mscclpp.hpp"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <mpi.h>
|
||||
|
||||
int main()
|
||||
{
|
||||
int rank, worldSize;
|
||||
MPI_Init(NULL, NULL);
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &worldSize);
|
||||
|
||||
std::shared_ptr<mscclpp::Bootstrap> bootstrap(new mscclpp::Bootstrap(rank, worldSize));
|
||||
// bootstrap->Initialize("costsim-dev-00000A:50000");
|
||||
mscclpp::UniqueId id;
|
||||
if (rank == 0)
|
||||
id = bootstrap->createUniqueId();
|
||||
MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
|
||||
bootstrap->initialize(id);
|
||||
|
||||
std::vector<int> tmp(worldSize, 0);
|
||||
tmp[rank] = rank + 1;
|
||||
void test_allgather(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap){
|
||||
std::vector<int> tmp(bootstrap->getNranks(), 0);
|
||||
tmp[bootstrap->getRank()] = bootstrap->getRank() + 1;
|
||||
bootstrap->allGather(tmp.data(), sizeof(int));
|
||||
for (int i = 0; i < worldSize; i++) {
|
||||
if (tmp[i] != i + 1)
|
||||
printf("error AllGather: rank %d: tmp[%d] = %d\n", rank, i, tmp[i]);
|
||||
for (int i = 0; i < bootstrap->getNranks(); i++) {
|
||||
assert(tmp[i] == i + 1);
|
||||
}
|
||||
printf("rank %d: AllGather test passed!\n", rank);
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "AllGather test passed!" << std::endl;
|
||||
}
|
||||
|
||||
void test_barrier(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap){
|
||||
bootstrap->barrier();
|
||||
printf("rank %d: Barrier test passed!\n", rank);
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "Barrier test passed!" << std::endl;
|
||||
}
|
||||
|
||||
for (int i = 0; i < worldSize; i++) {
|
||||
if (i == rank)
|
||||
void test_sendrecv(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap){
|
||||
for (int i = 0; i < bootstrap->getNranks(); i++) {
|
||||
if (bootstrap->getRank() == 0)
|
||||
continue;
|
||||
int msg1 = (rank + 1) * 3;
|
||||
int msg2 = (rank + 1) * 3 + 1;
|
||||
int msg3 = (rank + 1) * 3 + 2;
|
||||
int msg1 = (bootstrap->getRank() + 1) * 3;
|
||||
int msg2 = (bootstrap->getRank() + 1) * 3 + 1;
|
||||
int msg3 = (bootstrap->getRank() + 1) * 3 + 2;
|
||||
bootstrap->send(&msg1, sizeof(int), i, 0);
|
||||
bootstrap->send(&msg2, sizeof(int), i, 1);
|
||||
bootstrap->send(&msg3, sizeof(int), i, 2);
|
||||
}
|
||||
|
||||
for (int i = 0; i < worldSize; i++) {
|
||||
if (i == rank)
|
||||
for (int i = 0; i < bootstrap->getNranks(); i++) {
|
||||
if (i == bootstrap->getRank())
|
||||
continue;
|
||||
int msg1 = 0;
|
||||
int msg2 = 0;
|
||||
@@ -52,10 +44,92 @@ int main()
|
||||
bootstrap->recv(&msg2, sizeof(int), i, 1);
|
||||
bootstrap->recv(&msg3, sizeof(int), i, 2);
|
||||
bootstrap->recv(&msg1, sizeof(int), i, 0);
|
||||
if (msg1 != (i + 1) * 3 || msg2 != (i + 1) * 3 + 1 || msg3 != (i + 1) * 3 + 2)
|
||||
printf("error Send/Recv: rank %d: msg1 = %d, msg2 = %d\n", rank, msg1, msg2);
|
||||
assert(msg1 == (i + 1) * 3);
|
||||
assert(msg2 == (i + 1) * 3 + 1);
|
||||
assert(msg3 == (i + 1) * 3 + 2);
|
||||
}
|
||||
printf("rank %d: Send/Recv test passed!\n", rank);
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "Send/Recv test passed!" << std::endl;
|
||||
}
|
||||
|
||||
void test_all(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap){
|
||||
test_allgather(bootstrap);
|
||||
test_barrier(bootstrap);
|
||||
// test_sendrecv(bootstrap);
|
||||
}
|
||||
|
||||
void test_mscclpp_bootstrap_with_id(int rank, int worldSize){
|
||||
std::shared_ptr<mscclpp::Bootstrap> bootstrap(new mscclpp::Bootstrap(rank, worldSize));
|
||||
mscclpp::UniqueId id;
|
||||
if (bootstrap->getRank() == 0)
|
||||
id = bootstrap->createUniqueId();
|
||||
MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
|
||||
bootstrap->initialize(id);
|
||||
|
||||
test_all(bootstrap);
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "--- MSCCLPP::Bootstrap test with unique id passed! ---" << std::endl;
|
||||
}
|
||||
|
||||
void test_mscclpp_bootstrap_with_ip_port_pair(int rank, int worldSize, char* ipPortPiar){
|
||||
std::shared_ptr<mscclpp::Bootstrap> bootstrap(new mscclpp::Bootstrap(rank, worldSize));
|
||||
bootstrap->initialize(ipPortPiar);
|
||||
|
||||
test_all(bootstrap);
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "--- MSCCLPP::Bootstrap test with ip_port pair passed! ---" << std::endl;
|
||||
}
|
||||
|
||||
class MPIBootstrap : public mscclpp::BaseBootstrap {
|
||||
public:
|
||||
MPIBootstrap() : BaseBootstrap() {}
|
||||
int getRank() override {
|
||||
int rank;
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
||||
return rank;
|
||||
}
|
||||
int getNranks() override {
|
||||
int worldSize;
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &worldSize);
|
||||
return worldSize;
|
||||
}
|
||||
void allGather(void *sendbuf, int size) override {
|
||||
MPI_Allgather(MPI_IN_PLACE, 0, MPI_BYTE, sendbuf, size, MPI_BYTE, MPI_COMM_WORLD);
|
||||
}
|
||||
void barrier() override {
|
||||
MPI_Barrier(MPI_COMM_WORLD);
|
||||
}
|
||||
void send(void *sendbuf, int size, int dest, int tag) override {
|
||||
MPI_Send(sendbuf, size, MPI_BYTE, dest, tag, MPI_COMM_WORLD);
|
||||
}
|
||||
void recv(void *recvbuf, int size, int source, int tag) override {
|
||||
MPI_Recv(recvbuf, size, MPI_BYTE, source, tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
|
||||
}
|
||||
};
|
||||
|
||||
void test_mpi_bootstrap(){
|
||||
std::shared_ptr<mscclpp::BaseBootstrap> bootstrap(new MPIBootstrap());
|
||||
test_all(bootstrap);
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "--- MPI Bootstrap test passed! ---" << std::endl;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
int rank, worldSize;
|
||||
MPI_Init(&argc, &argv);
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &worldSize);
|
||||
if (argc > 2){
|
||||
if (rank == 0)
|
||||
std::cout << "Usage: " << argv[0] << " [ip:port]" << std::endl;
|
||||
MPI_Finalize();
|
||||
return 0;
|
||||
}
|
||||
test_mscclpp_bootstrap_with_id(rank, worldSize);
|
||||
if (argc == 2)
|
||||
test_mscclpp_bootstrap_with_ip_port_pair(rank, worldSize, argv[1]);
|
||||
test_mpi_bootstrap();
|
||||
|
||||
MPI_Finalize();
|
||||
return 0;
|
||||
|
||||
Reference in New Issue
Block a user