integrate with new interfaces in mscclpp.hpp

This commit is contained in:
Changho Hwang
2023-04-25 11:47:58 +00:00
parent 8428b49858
commit 31f7897d5d
6 changed files with 130 additions and 113 deletions

View File

@@ -134,7 +134,7 @@ HEADERS := $(wildcard src/include/*.h)
CPPSOURCES := $(shell find ./ -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)' -not -path "./build/*" -not -path "./python/*")
PYTHONCPPSOURCES := $(shell find ./python/src/ -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)')
INCEXPORTS := mscclpp.h mscclppfifo.h
INCEXPORTS := mscclpp.h mscclppfifo.h mscclpp.hpp mscclppfifo.hpp
INCTARGETS := $(INCEXPORTS:%=$(BUILDDIR)/$(INCDIR)/%)
LIBNAME := libmscclpp.so
@@ -198,7 +198,7 @@ $(BUILDDIR)/$(OBJDIR)/$(UTDIR)/%.o: $(UTDIR)/%.cc $(HEADERS)
@mkdir -p $(@D)
$(CXX) -o $@ $(INCLUDE) $(CXXFLAGS) -c $<
$(BUILDDIR)/$(INCDIR)/%.h: src/$(INCDIR)/%.h
$(BUILDDIR)/$(INCDIR)/%: src/$(INCDIR)/%
@mkdir -p $(@D)
cp $< $@
@@ -216,7 +216,7 @@ $(BUILDDIR)/$(BINDIR)/$(UTDIR)/%: $(BUILDDIR)/$(OBJDIR)/$(UTDIR)/%.o $(LIBOBJTAR
# Compile .cc tests
$(BUILDDIR)/$(OBJDIR)/$(TESTSDIR)/%.o: $(TESTSDIR)/%.cc $(INCTARGETS)
@mkdir -p $(@D)
$(CXX) -o $@ -I$(BUILDDIR)/$(INCDIR) $(MPI_INC) $(CXXFLAGS) -Isrc/include -c $< $(MPI_MACRO)
$(CXX) -o $@ -I$(BUILDDIR)/$(INCDIR) $(MPI_INC) $(CXXFLAGS) -c $< $(MPI_MACRO)
# Compile .cu tests
$(BUILDDIR)/$(OBJDIR)/$(TESTSDIR)/%.o: $(TESTSDIR)/%.cu $(INCTARGETS)

View File

@@ -1,6 +1,8 @@
#include "mscclpp.hpp"
#include "bootstrap.h"
#include "utils.h"
#include "checks.hpp"
#include "api.h"
#include <cstring>
#include <mutex>
@@ -11,6 +13,8 @@
#include <sys/resource.h>
#include <sys/types.h>
using namespace mscclpp;
namespace {
uint64_t hashUniqueId(const mscclppBootstrapHandle& id)
{
@@ -57,24 +61,33 @@ struct extInfo
mscclppSocketAddress extAddressListen;
};
class mscclppBootstrap::Impl
// Concrete layout of the bootstrap unique ID. The implementation copies it
// byte-for-byte into and out of the opaque public UniqueId with std::memcpy.
struct UniqueIdInternal
{
// Filled with random data when the ID is created; also passed to
// mscclppSocketInit() when opening sockets to peers.
uint64_t magic;
// Socket address copied from netIfAddr_ when the ID is created.
union mscclppSocketAddress addr;
};
// The internal layout has to fit inside the public UniqueId buffer, otherwise
// copying it into a UniqueId would write past the end of that buffer.
static_assert(sizeof(UniqueIdInternal) <= sizeof(UniqueId),
"UniqueIdInternal is too large to fit into UniqueId");
class DefaultBootstrap::Impl
{
public:
Impl(int rank, int nRanks);
~Impl();
void Initialize(const UniqueId uniqueId);
void Initialize(std::string ipPortPair);
void EstablishConnections();
UniqueId GetUniqueId();
void AllGather(void* allData, int size);
void Send(void* data, int size, int peer, int tag);
void Recv(void* data, int size, int peer, int tag);
void Barrier();
void Close();
void initialize(const UniqueId uniqueId);
void initialize(std::string ipPortPair);
void establishConnections();
UniqueId createUniqueId();
UniqueId getUniqueId() const;
void allGather(void* allData, int size);
void send(void* data, int size, int peer, int tag);
void recv(void* data, int size, int peer, int tag);
void barrier();
void close();
UniqueId uniqueId_;
private:
UniqueIdInternal uniqueId_;
int rank_;
int nRanks_;
bool netInitialized;
@@ -103,34 +116,38 @@ private:
// UniqueId MscclppBootstrap::Impl::uniqueId_;
mscclppBootstrap::Impl::Impl(int rank, int nRanks)
// Records this process's rank and the total rank count, and sizes the
// per-rank tables (peer addresses and barrier scratch array) to nRanks
// entries. No network setup happens here; netInitialized starts out false.
DefaultBootstrap::Impl::Impl(int rank, int nRanks)
  : rank_(rank),
    nRanks_(nRanks),
    netInitialized(false),
    peerCommAddresses_(nRanks, mscclppSocketAddress()),
    barrierArr_(nRanks, 0),
    abortFlag_(nullptr)
{
}
UniqueId mscclppBootstrap::Impl::GetUniqueId()
// Packs the internal unique ID (uniqueId_) into the opaque public UniqueId
// and returns it by value. Does not modify the bootstrap state.
UniqueId DefaultBootstrap::Impl::getUniqueId() const
{
  // Value-initialize the result: only sizeof(uniqueId_) bytes are overwritten
  // below, so without this the remaining bytes of the public buffer would be
  // indeterminate when the ID is returned and later shipped as raw bytes.
  UniqueId ret{};
  std::memcpy(&ret, &uniqueId_, sizeof(uniqueId_));
  return ret;
}
UniqueId DefaultBootstrap::Impl::createUniqueId()
{
netInit("");
MSCCLPPTHROW(getRandomData(&uniqueId_.magic, sizeof(uniqueId_.magic)));
std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(mscclppSocketAddress));
bootstrapCreateRoot();
return uniqueId_;
return getUniqueId();
}
void mscclppBootstrap::Impl::Initialize(const UniqueId uniqueId)
void DefaultBootstrap::Impl::initialize(const UniqueId uniqueId)
{
netInit("");
uniqueId_.magic = uniqueId.magic;
uniqueId_.addr = uniqueId.addr;
// printf("addr = %s port = %d\n", inet_ntoa(uniqueId_.addr.sin.sin_addr), (int)ntohs(uniqueId_.addr.sin.sin_port));
std::memcpy(&uniqueId_, &uniqueId, sizeof(uniqueId_));
EstablishConnections();
establishConnections();
}
void mscclppBootstrap::Impl::Initialize(std::string ipPortPair)
void DefaultBootstrap::Impl::initialize(std::string ipPortPair)
{
netInit(ipPortPair);
@@ -142,17 +159,17 @@ void mscclppBootstrap::Impl::Initialize(std::string ipPortPair)
bootstrapCreateRoot();
}
EstablishConnections();
establishConnections();
}
mscclppBootstrap::Impl::~Impl()
// Waits for the root thread to finish, if one is joinable, before the
// members are destroyed.
DefaultBootstrap::Impl::~Impl()
{
  if (!rootThread_.joinable()) {
    return;
  }
  rootThread_.join();
}
void mscclppBootstrap::Impl::getRemoteAddresses(mscclppSocket* listenSock,
void DefaultBootstrap::Impl::getRemoteAddresses(mscclppSocket* listenSock,
std::vector<mscclppSocketAddress>& rankAddresses,
std::vector<mscclppSocketAddress>& rankAddressesRoot,
int& rank)
@@ -181,7 +198,7 @@ void mscclppBootstrap::Impl::getRemoteAddresses(mscclppSocket* listenSock,
rank = info.rank;
}
void mscclppBootstrap::Impl::sendHandleToPeer(int peer,
void DefaultBootstrap::Impl::sendHandleToPeer(int peer,
const std::vector<mscclppSocketAddress>& rankAddresses,
const std::vector<mscclppSocketAddress>& rankAddressesRoot)
{
@@ -193,7 +210,7 @@ void mscclppBootstrap::Impl::sendHandleToPeer(int peer,
MSCCLPPTHROW(mscclppSocketClose(&sock));
}
void mscclppBootstrap::Impl::bootstrapCreateRoot()
void DefaultBootstrap::Impl::bootstrapCreateRoot()
{
mscclppSocket listenSock;
@@ -216,7 +233,7 @@ void mscclppBootstrap::Impl::bootstrapCreateRoot()
rootThread_ = std::thread(lambda);
}
void mscclppBootstrap::Impl::bootstrapRoot(mscclppSocket listenSock)
void DefaultBootstrap::Impl::bootstrapRoot(mscclppSocket listenSock)
{
int numCollected = 0;
std::vector<mscclppSocketAddress> rankAddresses(this->nRanks_, mscclppSocketAddress());
@@ -245,7 +262,7 @@ void mscclppBootstrap::Impl::bootstrapRoot(mscclppSocket listenSock)
TRACE(MSCCLPP_INIT, "DONE");
}
void mscclppBootstrap::Impl::netInit(std::string ipPortPair)
void DefaultBootstrap::Impl::netInit(std::string ipPortPair)
{
if (netInitialized)
return;
@@ -271,7 +288,7 @@ void mscclppBootstrap::Impl::netInit(std::string ipPortPair)
netInitialized = true;
}
void mscclppBootstrap::Impl::EstablishConnections()
void DefaultBootstrap::Impl::establishConnections()
{
mscclppSocketAddress nextAddr;
mscclppSocket sock, listenSockRoot;
@@ -332,12 +349,12 @@ void mscclppBootstrap::Impl::EstablishConnections()
// AllGather all listen handlers
MSCCLPPTHROW(mscclppSocketGetAddr(&this->listenSock_, &this->peerCommAddresses_[rank_]));
AllGather(this->peerCommAddresses_.data(), sizeof(mscclppSocketAddress));
allGather(this->peerCommAddresses_.data(), sizeof(mscclppSocketAddress));
TRACE(MSCCLPP_INIT, "rank %d nranks %d - DONE", rank_, nRanks_);
}
void mscclppBootstrap::Impl::AllGather(void* allData, int size)
void DefaultBootstrap::Impl::allGather(void* allData, int size)
{
char* data = static_cast<char*>(allData);
int rank = this->rank_;
@@ -362,13 +379,13 @@ void mscclppBootstrap::Impl::AllGather(void* allData, int size)
TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d - DONE", rank, nRanks, size);
}
void mscclppBootstrap::Impl::netSend(mscclppSocket* sock, const void* data, int size)
// Sends one length-prefixed message on `sock`: first the payload size as an
// int, then `size` bytes read from `data`. Failures are reported through
// MSCCLPPTHROW.
void DefaultBootstrap::Impl::netSend(mscclppSocket* sock, const void* data, int size)
{
  // Header: the byte count of the payload that follows.
  MSCCLPPTHROW(mscclppSocketSend(sock, &size, sizeof(size)));

  // mscclppSocketSend takes a non-const pointer, hence the cast.
  void* payload = const_cast<void*>(data);
  MSCCLPPTHROW(mscclppSocketSend(sock, payload, size));
}
void mscclppBootstrap::Impl::netRecv(mscclppSocket* sock, void* data, int size)
void DefaultBootstrap::Impl::netRecv(mscclppSocket* sock, void* data, int size)
{
int recvSize;
MSCCLPPTHROW(mscclppSocketRecv(sock, &recvSize, sizeof(int)));
@@ -378,7 +395,7 @@ void mscclppBootstrap::Impl::netRecv(mscclppSocket* sock, void* data, int size)
MSCCLPPTHROW(mscclppSocketRecv(sock, data, std::min(recvSize, size)));
}
void mscclppBootstrap::Impl::Send(void* data, int size, int peer, int tag)
void DefaultBootstrap::Impl::send(void* data, int size, int peer, int tag)
{
mscclppSocket sock;
MSCCLPPTHROW(mscclppSocketInit(&sock, &this->peerCommAddresses_[peer], this->uniqueId_.magic,
@@ -391,7 +408,7 @@ void mscclppBootstrap::Impl::Send(void* data, int size, int peer, int tag)
MSCCLPPTHROW(mscclppSocketClose(&sock));
}
void mscclppBootstrap::Impl::Recv(void* data, int size, int peer, int tag)
void DefaultBootstrap::Impl::recv(void* data, int size, int peer, int tag)
{
// search over all unexpected messages
for (auto it = unexpectedMessages_.begin(); it != unexpectedMessages_.end(); ++it){
@@ -421,62 +438,67 @@ void mscclppBootstrap::Impl::Recv(void* data, int size, int peer, int tag)
}
}
void mscclppBootstrap::Impl::Barrier()
void DefaultBootstrap::Impl::barrier()
{
AllGather(barrierArr_.data(), sizeof(int));
allGather(barrierArr_.data(), sizeof(int));
}
void mscclppBootstrap::Impl::Close()
void DefaultBootstrap::Impl::close()
{
MSCCLPPTHROW(mscclppSocketClose(&this->listenSock_));
MSCCLPPTHROW(mscclppSocketClose(&this->ringSendSocket_));
MSCCLPPTHROW(mscclppSocketClose(&this->ringRecvSocket_));
}
mscclppBootstrap::mscclppBootstrap(int rank, int nRanks)
// Builds the private implementation object for the given rank and rank
// count. The arguments are forwarded unchanged to Impl's constructor.
MSCCLPP_API_CPP DefaultBootstrap::DefaultBootstrap(int rank, int nRanks)
  : pimpl_(std::make_unique<Impl>(rank, nRanks))
{
}
UniqueId mscclppBootstrap::GetUniqueId()
MSCCLPP_API_CPP UniqueId DefaultBootstrap::createUniqueId()
{
return pimpl_->GetUniqueId();
return pimpl_->createUniqueId();
}
void mscclppBootstrap::Send(void* data, int size, int peer, int tag)
MSCCLPP_API_CPP UniqueId DefaultBootstrap::getUniqueId() const
{
pimpl_->Send(data, size, peer, tag);
return pimpl_->getUniqueId();
}
void mscclppBootstrap::Recv(void* data, int size, int peer, int tag)
MSCCLPP_API_CPP void DefaultBootstrap::send(void* data, int size, int peer, int tag)
{
pimpl_->Recv(data, size, peer, tag);
pimpl_->send(data, size, peer, tag);
}
void mscclppBootstrap::AllGather(void* allData, int size)
MSCCLPP_API_CPP void DefaultBootstrap::recv(void* data, int size, int peer, int tag)
{
pimpl_->AllGather(allData, size);
pimpl_->recv(data, size, peer, tag);
}
void mscclppBootstrap::Initialize(UniqueId uniqueId)
MSCCLPP_API_CPP void DefaultBootstrap::allGather(void* allData, int size)
{
pimpl_->Initialize(uniqueId);
pimpl_->allGather(allData, size);
}
void mscclppBootstrap::Initialize(std::string ipPortPair)
MSCCLPP_API_CPP void DefaultBootstrap::initialize(UniqueId uniqueId)
{
pimpl_->Initialize(ipPortPair);
pimpl_->initialize(uniqueId);
}
void mscclppBootstrap::Barrier()
MSCCLPP_API_CPP void DefaultBootstrap::initialize(std::string ipPortPair)
{
pimpl_->Barrier();
pimpl_->initialize(ipPortPair);
}
mscclppBootstrap::~mscclppBootstrap()
MSCCLPP_API_CPP void DefaultBootstrap::barrier()
{
pimpl_->Close();
pimpl_->barrier();
}
// Destructor: asks the implementation object to close its sockets before it
// is destroyed.
// NOTE(review): Impl::close() checks each socket close with MSCCLPPTHROW. If
// that macro throws, an exception escaping this destructor would terminate
// the program — confirm, and consider guarding the call.
MSCCLPP_API_CPP DefaultBootstrap::~DefaultBootstrap()
{
pimpl_->close();
}
// ------------------- Old bootstrap functions -------------------

View File

@@ -5,35 +5,6 @@
#include "comm.h"
struct UniqueId
{
uint64_t magic;
union mscclppSocketAddress addr;
};
static_assert(sizeof(UniqueId) <= sizeof(mscclppUniqueId),
"Bootstrap handle is too large to fit inside MSCCLPP unique ID");
class __attribute__((visibility("default"))) mscclppBootstrap : public Bootstrap
{
public:
mscclppBootstrap(int rank, int nRanks);
~mscclppBootstrap();
UniqueId GetUniqueId();
void Initialize(UniqueId uniqueId);
void Initialize(std::string ipPortPair);
void Send(void* data, int size, int peer, int tag) override;
void Recv(void* data, int size, int peer, int tag) override;
void AllGather(void* allData, int size) override;
void Barrier() override;
private:
class Impl;
std::unique_ptr<Impl> pimpl_;
};
// ------------------- Old bootstrap headers: to be removed -------------------
struct mscclppBootstrapHandle

View File

@@ -248,16 +248,6 @@ typedef enum
} mscclppResult_t;
class Bootstrap {
public:
Bootstrap(){};
virtual ~Bootstrap() = default;
virtual void Send(void* data, int size, int peer, int tag) = 0;
virtual void Recv(void* data, int size, int peer, int tag) = 0;
virtual void AllGather(void* allData, int size) = 0;
virtual void Barrier() = 0;
};
/* Create a unique ID for communication. Only needs to be called by one process.
* Use with mscclppCommInitRankFromId().
* All processes need to provide the same ID to mscclppCommInitRankFromId().

View File

@@ -13,12 +13,51 @@
#include <vector>
#include <memory>
#include <string>
#include <functional>
#include <mscclppfifo.hpp>
namespace mscclpp {
// Size, in bytes, of the opaque payload carried by a UniqueId.
#define MSCCLPP_UNIQUE_ID_BYTES 128

// Opaque fixed-size identifier. The bytes in `internal` have no public
// meaning; the object is a plain byte buffer that can be copied or
// transmitted as raw memory.
struct UniqueId
{
  char internal[MSCCLPP_UNIQUE_ID_BYTES];
};
// Abstract interface for the out-of-band communication used between ranks.
// Concrete bootstraps implement the four primitives below.
class Bootstrap
{
public:
  Bootstrap() = default;
  virtual ~Bootstrap() = default;

  // Sends `size` bytes starting at `data` to rank `peer`, labelled with `tag`.
  virtual void send(void* data, int size, int peer, int tag) = 0;

  // Receives the message from rank `peer` labelled with `tag` into `data`
  // (`size` is the capacity of `data` in bytes).
  virtual void recv(void* data, int size, int peer, int tag) = 0;

  // Gathers one `size`-byte chunk from every rank into `allData`.
  // NOTE(review): presumably `allData` holds nRanks * size bytes and each rank
  // pre-fills the chunk at its own index — confirm against implementations.
  virtual void allGather(void* allData, int size) = 0;

  // Synchronization point across ranks.
  virtual void barrier() = 0;
};
// Bootstrap implementation provided by the library. All state lives in a
// private Impl object (pimpl), so this header exposes no implementation
// details.
class DefaultBootstrap : public Bootstrap
{
public:
// Creates a bootstrap for process `rank` out of `nRanks` processes.
DefaultBootstrap(int rank, int nRanks);
~DefaultBootstrap();
// Creates a new unique ID and returns it; the ID is then handed to
// initialize(UniqueId) on every rank.
UniqueId createUniqueId();
// Returns the unique ID currently held by this bootstrap.
UniqueId getUniqueId() const;
// Initializes the bootstrap from a unique ID obtained via createUniqueId().
void initialize(UniqueId uniqueId);
// Initializes the bootstrap from an "ip:port" style string.
// NOTE(review): exact accepted format is defined by the implementation — confirm.
void initialize(std::string ipPortPair);
// Bootstrap interface; see the base class for the contract of each call.
void send(void* data, int size, int peer, int tag) override;
void recv(void* data, int size, int peer, int tag) override;
void allGather(void* allData, int size) override;
void barrier() override;
private:
// Opaque implementation type, defined in the library's source file.
class Impl;
std::unique_ptr<Impl> pimpl_;
};
struct alignas(16) SignalEpochId {
// every signal(), increments this and either:
// 1) proxy thread pushes it to the remote peer's localSignalEpochId->proxy
@@ -381,11 +420,6 @@ struct SimpleDeviceConnection {
BufferHandle src;
};
#define MSCCLPP_UNIQUE_ID_BYTES 128
struct UniqueId {
char internal[MSCCLPP_UNIQUE_ID_BYTES];
};
/* Create a unique ID for communication. Only needs to be called by one process.
* Use with mscclppCommInitRankFromId().
* All processes need to provide the same ID to mscclppCommInitRankFromId().

View File

@@ -1,4 +1,4 @@
#include "bootstrap.h"
#include "mscclpp.hpp"
#include <memory>
@@ -11,24 +11,24 @@ int main()
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &worldSize);
std::shared_ptr<mscclppBootstrap> bootstrap(new mscclppBootstrap(rank, worldSize));
std::shared_ptr<mscclpp::DefaultBootstrap> bootstrap(new mscclpp::DefaultBootstrap(rank, worldSize));
// bootstrap->Initialize("costsim-dev-00000A:50000");
UniqueId id;
mscclpp::UniqueId id;
if (rank == 0)
id = bootstrap->GetUniqueId();
id = bootstrap->createUniqueId();
MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
bootstrap->Initialize(id);
bootstrap->initialize(id);
std::vector<int> tmp(worldSize, 0);
tmp[rank] = rank+1;
bootstrap->AllGather(tmp.data(), sizeof(int));
bootstrap->allGather(tmp.data(), sizeof(int));
for (int i = 0; i < worldSize; i++){
if (tmp[i] != i+1)
printf("error AllGather: rank %d: tmp[%d] = %d\n", rank, i, tmp[i]);
}
printf("rank %d: AllGather test passed!\n", rank);
bootstrap->Barrier();
bootstrap->barrier();
printf("rank %d: Barrier test passed!\n", rank);
for (int i = 0; i < worldSize; i++){
@@ -36,8 +36,8 @@ int main()
continue;
int msg1 = (rank + 1)*2;
int msg2 = (rank + 1)*2+1;
bootstrap->Send(&msg1, sizeof(int), i, 0);
bootstrap->Send(&msg2, sizeof(int), i, 1);
bootstrap->send(&msg1, sizeof(int), i, 0);
bootstrap->send(&msg2, sizeof(int), i, 1);
}
for (int i = 0; i < worldSize; i++){
@@ -46,8 +46,8 @@ int main()
int msg1 = 0;
int msg2 = 0;
// recv them in the opposite order to check correctness
bootstrap->Recv(&msg2, sizeof(int), i, 1);
bootstrap->Recv(&msg1, sizeof(int), i, 0);
bootstrap->recv(&msg2, sizeof(int), i, 1);
bootstrap->recv(&msg1, sizeof(int), i, 0);
if (msg1 != (i+1)*2 || msg2 != (i+1)*2+1)
printf("error Send/Recv: rank %d: msg1 = %d, msg2 = %d\n", rank, msg1, msg2);
}