Cleanup old files and functions (#86)

This commit is contained in:
Changho Hwang
2023-06-01 17:34:57 +08:00
committed by GitHub
parent 457c422791
commit 9cee6c4a74
41 changed files with 657 additions and 2030 deletions

View File

@@ -1,34 +1,28 @@
#include <mscclpp/core.hpp>
#include <mpi.h>
#include <cassert>
#include <iostream>
#include <memory>
#include <mpi.h>
#include <mscclpp/core.hpp>
void test_allgather(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap)
{
void test_allgather(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap) {
std::vector<int> tmp(bootstrap->getNranks(), 0);
tmp[bootstrap->getRank()] = bootstrap->getRank() + 1;
bootstrap->allGather(tmp.data(), sizeof(int));
for (int i = 0; i < bootstrap->getNranks(); i++) {
assert(tmp[i] == i + 1);
}
if (bootstrap->getRank() == 0)
std::cout << "AllGather test passed!" << std::endl;
if (bootstrap->getRank() == 0) std::cout << "AllGather test passed!" << std::endl;
}
void test_barrier(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap)
{
void test_barrier(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap) {
bootstrap->barrier();
if (bootstrap->getRank() == 0)
std::cout << "Barrier test passed!" << std::endl;
if (bootstrap->getRank() == 0) std::cout << "Barrier test passed!" << std::endl;
}
void test_sendrecv(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap)
{
void test_sendrecv(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap) {
for (int i = 0; i < bootstrap->getNranks(); i++) {
if (bootstrap->getRank() == i)
continue;
if (bootstrap->getRank() == i) continue;
int msg1 = (bootstrap->getRank() + 1) * 3;
int msg2 = (bootstrap->getRank() + 1) * 3 + 1;
int msg3 = (bootstrap->getRank() + 1) * 3 + 2;
@@ -38,8 +32,7 @@ void test_sendrecv(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap)
}
for (int i = 0; i < bootstrap->getNranks(); i++) {
if (bootstrap->getRank() == i)
continue;
if (bootstrap->getRank() == i) continue;
int msg1 = 0;
int msg2 = 0;
int msg3 = 0;
@@ -51,102 +44,79 @@ void test_sendrecv(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap)
assert(msg2 == (i + 1) * 3 + 1);
assert(msg3 == (i + 1) * 3 + 2);
}
if (bootstrap->getRank() == 0)
std::cout << "Send/Recv test passed!" << std::endl;
if (bootstrap->getRank() == 0) std::cout << "Send/Recv test passed!" << std::endl;
}
void test_all(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap)
{
void test_all(std::shared_ptr<mscclpp::BaseBootstrap> bootstrap) {
test_allgather(bootstrap);
test_barrier(bootstrap);
test_sendrecv(bootstrap);
}
void test_mscclpp_bootstrap_with_id(int rank, int worldSize)
{
void test_mscclpp_bootstrap_with_id(int rank, int worldSize) {
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(rank, worldSize);
mscclpp::UniqueId id;
if (bootstrap->getRank() == 0)
id = bootstrap->createUniqueId();
if (bootstrap->getRank() == 0) id = bootstrap->createUniqueId();
MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
bootstrap->initialize(id);
test_all(bootstrap);
if (bootstrap->getRank() == 0)
std::cout << "--- MSCCLPP::Bootstrap test with unique id passed! ---" << std::endl;
if (bootstrap->getRank() == 0) std::cout << "--- MSCCLPP::Bootstrap test with unique id passed! ---" << std::endl;
}
void test_mscclpp_bootstrap_with_ip_port_pair(int rank, int worldSize, char* ipPortPiar)
{
void test_mscclpp_bootstrap_with_ip_port_pair(int rank, int worldSize, char* ipPortPair) {
std::shared_ptr<mscclpp::Bootstrap> bootstrap(new mscclpp::Bootstrap(rank, worldSize));
bootstrap->initialize(ipPortPiar);
bootstrap->initialize(ipPortPair);
test_all(bootstrap);
if (bootstrap->getRank() == 0)
std::cout << "--- MSCCLPP::Bootstrap test with ip_port pair passed! ---" << std::endl;
if (bootstrap->getRank() == 0) std::cout << "--- MSCCLPP::Bootstrap test with ip_port pair passed! ---" << std::endl;
}
class MPIBootstrap : public mscclpp::BaseBootstrap
{
public:
MPIBootstrap() : BaseBootstrap()
{
}
int getRank() override
{
class MPIBootstrap : public mscclpp::BaseBootstrap {
public:
MPIBootstrap() : BaseBootstrap() {}
int getRank() override {
int rank;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
return rank;
}
int getNranks() override
{
int getNranks() override {
int worldSize;
MPI_Comm_size(MPI_COMM_WORLD, &worldSize);
return worldSize;
}
void allGather(void* sendbuf, int size) override
{
void allGather(void* sendbuf, int size) override {
MPI_Allgather(MPI_IN_PLACE, 0, MPI_BYTE, sendbuf, size, MPI_BYTE, MPI_COMM_WORLD);
}
void barrier() override
{
MPI_Barrier(MPI_COMM_WORLD);
}
void send(void* sendbuf, int size, int dest, int tag) override
{
void barrier() override { MPI_Barrier(MPI_COMM_WORLD); }
void send(void* sendbuf, int size, int dest, int tag) override {
MPI_Send(sendbuf, size, MPI_BYTE, dest, tag, MPI_COMM_WORLD);
}
void recv(void* recvbuf, int size, int source, int tag) override
{
void recv(void* recvbuf, int size, int source, int tag) override {
MPI_Recv(recvbuf, size, MPI_BYTE, source, tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
}
};
void test_mpi_bootstrap()
{
void test_mpi_bootstrap() {
std::shared_ptr<mscclpp::BaseBootstrap> bootstrap(new MPIBootstrap());
test_all(bootstrap);
if (bootstrap->getRank() == 0)
std::cout << "--- MPI Bootstrap test passed! ---" << std::endl;
if (bootstrap->getRank() == 0) std::cout << "--- MPI Bootstrap test passed! ---" << std::endl;
}
int main(int argc, char** argv)
{
int main(int argc, char** argv) {
int rank, worldSize;
MPI_Init(&argc, &argv);
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &worldSize);
if (argc > 2) {
if (rank == 0)
std::cout << "Usage: " << argv[0] << " [ip:port]" << std::endl;
if (rank == 0) std::cout << "Usage: " << argv[0] << " [ip:port]" << std::endl;
MPI_Finalize();
return 0;
}
test_mscclpp_bootstrap_with_id(rank, worldSize);
if (argc == 2)
test_mscclpp_bootstrap_with_ip_port_pair(rank, worldSize, argv[1]);
if (argc == 2) test_mscclpp_bootstrap_with_ip_port_pair(rank, worldSize, argv[1]);
test_mpi_bootstrap();
MPI_Finalize();
return 0;
}
}