This commit is contained in:
Saeed Maleki
2023-04-24 05:58:11 +00:00
parent 073460c341
commit a9cfb82fcb
3 changed files with 105 additions and 47 deletions

View File

@@ -59,20 +59,24 @@ struct extInfo
class MscclppBootstrap::Impl
{
public:
Impl(std::string ipPortPair, int rank, int nRanks, const UniqueId uniqueId);
Impl(int rank, int nRanks);
~Impl();
mscclppResult_t initialize();
void initialize(const UniqueId uniqueId);
void initialize(std::string ipPortPair);
mscclppResult_t establishConnections();
UniqueId getUniqueId();
mscclppResult_t allGather(void* allData, int size);
mscclppResult_t send(void* data, int size, int peer, int tag);
mscclppResult_t recv(void* data, int size, int peer, int tag);
mscclppResult_t barrier();
mscclppResult_t close();
static UniqueId uniqueId_;
UniqueId uniqueId_;
private:
int rank_;
int nRanks_;
bool netInitialized;
mscclppSocket listenSock_;
mscclppSocket ringRecvSocket_;
mscclppSocket ringSendSocket_;
@@ -95,37 +99,67 @@ private:
mscclppResult_t netInit(std::string ipPortPair);
};
UniqueId MscclppBootstrap::Impl::uniqueId_;
// UniqueId MscclppBootstrap::Impl::uniqueId_;
MscclppBootstrap::Impl::Impl(std::string ipPortPair, int rank, int nRanks, const UniqueId uniqueId)
: rank_(rank), nRanks_(nRanks), peerCommAddresses_(nRanks, mscclppSocketAddress()),
// Records this process's rank and the total rank count, and sets up empty
// per-peer address tables with one slot per rank. The body is empty: no
// network work is done here, and netInitialized starts out false.
MscclppBootstrap::Impl::Impl(int rank, int nRanks)
  : rank_(rank),
    nRanks_(nRanks),
    netInitialized(false),
    peerCommAddresses_(nRanks, mscclppSocketAddress()),
    peerProxyAddresses_(nRanks, mscclppSocketAddress()),
    abortFlag_(nullptr)
{
}
// Produces a new UniqueId: a random `magic` value plus a copy of netIfAddr_
// in `addr`. The result is returned by value; uniqueId_ is not modified.
//
// Throws std::runtime_error if netInit("") or getRandomData() does not
// return mscclppSuccess.
UniqueId MscclppBootstrap::Impl::getUniqueId()
{
  // An empty ip:port string is passed to netInit.
  // NOTE(review): netIfAddr_ is presumably filled in by netInit - confirm.
  if (netInit("") != mscclppSuccess) {
    throw std::runtime_error("Failed to initialize network");
  }

  UniqueId id;
  if (getRandomData(&id.magic, sizeof(id.magic)) != mscclppSuccess) {
    throw std::runtime_error("getting random data failed");
  }

  std::memcpy(&id.addr, &netIfAddr_, sizeof(union mscclppSocketAddress));
  return id;
}
// Bootstraps this rank from a UniqueId supplied by the caller.
//
// Steps, in order:
//   1. netInit("") with an empty ip:port string.
//   2. Copy `magic` and `addr` from the argument into uniqueId_.
//   3. On rank 0 only, start bootstrapRoot() on rootThread_.
//   4. establishConnections().
//
// Throws std::runtime_error if step 1 or step 4 does not return
// mscclppSuccess.
void MscclppBootstrap::Impl::initialize(const UniqueId uniqueId)
{
  if (netInit("") != mscclppSuccess) {
    throw std::runtime_error("Failed to initialize network");
  }

  uniqueId_.magic = uniqueId.magic;
  uniqueId_.addr = uniqueId.addr;

  // The root thread is launched before establishConnections() is called.
  const bool isRoot = (rank_ == 0);
  if (isRoot) {
    rootThread_ = std::thread(&MscclppBootstrap::Impl::bootstrapRoot, this);
  }

  if (establishConnections() != mscclppSuccess) {
    throw std::runtime_error("Failed to establish connections");
  }
}
// Bootstraps this rank from an "ip:port" string instead of an exchanged
// UniqueId.
//
// Steps, in order:
//   1. netInit(ipPortPair).
//   2. Fill uniqueId_ locally: `magic` is the fixed constant 0xdeadbeef and
//      `addr` is a copy of netIfAddr_.
//   3. On rank 0 only, start bootstrapRoot() on rootThread_.
//   4. establishConnections().
//
// Throws std::runtime_error if step 1 or step 4 does not return
// mscclppSuccess.
void MscclppBootstrap::Impl::initialize(std::string ipPortPair)
{
  int ret = netInit(ipPortPair);
  if (ret != mscclppSuccess) {
    throw std::runtime_error("Failed to initialize network");
  }

  // No UniqueId is passed to this overload, so the magic is a constant.
  // NOTE(review): presumably fixed so that all ranks given the same ip:port
  // agree on it without any exchange - confirm.
  uniqueId_.magic = 0xdeadbeef;
  // NOTE(review): netIfAddr_ is presumably set by netInit(ipPortPair) - confirm.
  std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(union mscclppSocketAddress));

  // The root thread is launched before establishConnections() is called.
  if (rank_ == 0) {
    rootThread_ = std::thread(&MscclppBootstrap::Impl::bootstrapRoot, this);
  }

  ret = establishConnections();
  if (ret != mscclppSuccess) {
    throw std::runtime_error("Failed to establish connections");
  }
}
MscclppBootstrap::Impl::~Impl()
@@ -145,33 +179,39 @@ mscclppResult_t MscclppBootstrap::Impl::getRemoteAddresses(mscclppSocket* listen
mscclppResult_t res = mscclppSuccess;
mscclppSocketAddress zero;
printf("hh 0\n");
std::memset(&zero, 0, sizeof(mscclppSocketAddress));
res = mscclppSocketInit(&sock);
if (res != mscclppSuccess) {
WARN("Bootstrap Root : mscclppSocketInit failed");
return res;
}
printf("hh 1\n");
res = mscclppSocketAccept(&sock, listenSock);
if (res != mscclppSuccess) {
WARN("Bootstrap Root : mscclppSocketAccept failed");
return res;
}
printf("hh 2\n");
res = netRecv(&sock, &info, sizeof(info));
if (res != mscclppSuccess) {
WARN("Bootstrap Root : netRecv failed");
return res;
}
printf("hh 3\n");
res = mscclppSocketClose(&sock);
if (res != mscclppSuccess) {
WARN("Bootstrap Root : mscclppSocketClose failed");
return res;
}
printf("hh 4\n");
if (this->nRanks_ != info.nRanks) {
WARN("Bootstrap Root : mismatch in rank count from procs %d : %d", this->nRanks_, info.nRanks);
return res;
}
printf("hh 5\n");
if (std::memcmp(&zero, &rankAddressesRoot[info.rank], sizeof(mscclppSocketAddress)) != 0) {
WARN("Bootstrap Root : rank %d of %d ranks has already checked in", info.rank, this->nRanks_);
return res;
@@ -216,6 +256,7 @@ mscclppResult_t MscclppBootstrap::Impl::sendHandleToPeer(int peer,
mscclppResult_t MscclppBootstrap::Impl::bootstrapRoot()
{
printf("I am here0 magic %x\n", uniqueId_.magic);
mscclppResult_t res = mscclppSuccess;
int numCollected = 0;
std::vector<mscclppSocketAddress> rankAddresses(this->nRanks_, mscclppSocketAddress());
@@ -226,16 +267,20 @@ mscclppResult_t MscclppBootstrap::Impl::bootstrapRoot()
std::memset(rankAddressesRoot.data(), 0, sizeof(mscclppSocketAddress) * this->nRanks_);
setFilesLimit();
printf("I am here1 %x\n", uniqueId_.magic);
mscclppSocket listenSock;
MSCCLPPCHECK(
mscclppSocketInit(&listenSock, &uniqueId_.addr, uniqueId_.magic, mscclppSocketTypeBootstrap, nullptr, 0));
MSCCLPPCHECK(mscclppSocketListen(&listenSock));
printf("I am here2\n");
TRACE(MSCCLPP_INIT, "BEGIN");
printf("I am here3\n");
/* Receive addresses from all ranks */
do {
int rank;
res = getRemoteAddresses(&listenSock, rankAddresses, rankAddressesRoot, rank);
printf("I am here4\n");
if (res != mscclppSuccess) {
WARN("Bootstrap Root : getRemoteAddresses failed");
break;
@@ -262,6 +307,8 @@ mscclppResult_t MscclppBootstrap::Impl::bootstrapRoot()
mscclppResult_t MscclppBootstrap::Impl::netInit(std::string ipPortPair)
{
if (netInitialized)
return mscclppSuccess;
if (!ipPortPair.empty()) {
union mscclppSocketAddress remoteAddr;
if (mscclppSocketGetAddrFromString(&remoteAddr, ipPortPair.c_str()) != mscclppSuccess) {
@@ -284,17 +331,18 @@ mscclppResult_t MscclppBootstrap::Impl::netInit(std::string ipPortPair)
std::sprintf(line, " %s:", netIfName_);
mscclppSocketToString(&netIfAddr_, line + strlen(line));
INFO(MSCCLPP_INIT, "Bootstrap : Using%s", line);
netInitialized = true;
return mscclppSuccess;
}
mscclppResult_t MscclppBootstrap::Impl::initialize()
mscclppResult_t MscclppBootstrap::Impl::establishConnections()
{
mscclppSocket* proxySocket;
mscclppSocketAddress nextAddr;
mscclppSocket sock, listenSockRoot;
extInfo info;
TRACE(MSCCLPP_INIT, "rank %d nranks %d", rank, nranks);
TRACE(MSCCLPP_INIT, "rank %d nranks %d", rank_, nRanks_);
info.rank = this->rank_;
info.nRanks = this->nRanks_;
@@ -322,11 +370,21 @@ mscclppResult_t MscclppBootstrap::Impl::initialize()
randomSleep(this->rank_);
}
char line[SOCKET_NAME_MAXLEN + MAX_IF_NAME_SIZE + 2];
std::sprintf(line, " %s:", netIfName_);
mscclppSocketToString(&this->uniqueId_.addr, line + strlen(line));
printf("tt 1 %s\n", line);
// send info on my listening socket to root
MSCCLPPCHECK(mscclppSocketInit(&sock, &this->uniqueId_.addr, magic, mscclppSocketTypeBootstrap, this->abortFlag_));
printf("tt 2\n");
MSCCLPPCHECK(mscclppSocketConnect(&sock));
printf("tt 3\n");
MSCCLPPCHECK(netSend(&sock, &info, sizeof(info)));
printf("tt 4\n");
MSCCLPPCHECK(mscclppSocketClose(&sock));
printf("tt 5\n");
// get info on my "next" rank in the bootstrap ring from root
MSCCLPPCHECK(mscclppSocketInit(&sock));
@@ -353,7 +411,7 @@ mscclppResult_t MscclppBootstrap::Impl::initialize()
MSCCLPPCHECK(mscclppSocketGetAddr(proxySocket, &this->peerProxyAddresses_[rank_]));
MSCCLPPCHECK(allGather(this->peerProxyAddresses_.data(), sizeof(union mscclppSocketAddress)));
TRACE(MSCCLPP_INIT, "rank %d nranks %d - DONE", rank, nranks);
TRACE(MSCCLPP_INIT, "rank %d nranks %d - DONE", rank_, nRanks_);
return mscclppSuccess;
}
@@ -380,7 +438,7 @@ mscclppResult_t MscclppBootstrap::Impl::allGather(void* allData, int size)
MSCCLPPCHECK(netRecv(&this->ringRecvSocket_, data + rSlice * size, size));
}
TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d - DONE", rank, nranks, size);
TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d - DONE", rank, nRanks, size);
return mscclppSuccess;
}
@@ -432,23 +490,15 @@ mscclppResult_t MscclppBootstrap::Impl::close()
return mscclppSuccess;
}
MscclppBootstrap::MscclppBootstrap(std::string ipPortPair, int rank, int nRanks)
MscclppBootstrap::MscclppBootstrap(int rank, int nRanks)
{
UniqueId uniqueId;
std::memset(&uniqueId, 0, sizeof(uniqueId));
// pimpl_ = std::make_unique<Impl>(ipPortPair, rank, nRanks, uniqueId);
pimpl_ = new Impl(ipPortPair, rank, nRanks, uniqueId);
}
MscclppBootstrap::MscclppBootstrap(UniqueId uniqueId, int rank, int nRanks)
{
pimpl_ = new Impl("", rank, nRanks, uniqueId);
// pimpl_ = std::make_unique<Impl>("", rank, nRanks, uniqueId);
pimpl_ = new Impl(rank, nRanks);
}
// Returns a UniqueId obtained from the implementation object's
// getUniqueId(). Any exception thrown by that call propagates to the caller.
UniqueId MscclppBootstrap::GetUniqueId()
{
  return pimpl_->getUniqueId();
}
void MscclppBootstrap::Send(void* data, int size, int peer, int tag)
@@ -475,12 +525,14 @@ void MscclppBootstrap::AllGather(void* allData, int size)
}
}
void MscclppBootstrap::Initialize()
// Initializes the bootstrap from a UniqueId by forwarding it to the
// implementation object's initialize(const UniqueId). This wrapper adds no
// error handling of its own; exceptions thrown by that call propagate.
void MscclppBootstrap::Initialize(const UniqueId uniqueId)
{
  pimpl_->initialize(uniqueId);
}
// Initializes the bootstrap from an "ip:port" string by forwarding it to the
// implementation object's initialize(std::string). This wrapper adds no error
// handling of its own; exceptions thrown by that call propagate.
void MscclppBootstrap::Initialize(std::string ipPortPair)
{
pimpl_->initialize(ipPortPair);
}
void MscclppBootstrap::Barrier()

View File

@@ -17,13 +17,13 @@ static_assert(sizeof(UniqueId) <= sizeof(mscclppUniqueId),
class __attribute__((visibility("default"))) MscclppBootstrap : public Bootstrap
{
public:
MscclppBootstrap(std::string ipPortPair, int rank, int nRanks);
MscclppBootstrap(UniqueId uniqueId, int rank, int nRanks);
MscclppBootstrap(int rank, int nRanks);
~MscclppBootstrap() override = default;
static UniqueId GetUniqueId();
UniqueId GetUniqueId();
void Initialize();
void Initialize(const UniqueId uniqueId);
void Initialize(std::string ipPortPair);
void Send(void* data, int size, int peer, int tag) override;
void Recv(void* data, int size, int peer, int tag) override;
void AllGather(void* allData, int size) override;

View File

@@ -11,7 +11,13 @@ int main()
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &worldSize);
std::shared_ptr<Bootstrap> bootstrap(new MscclppBootstrap("", rank, worldSize));
std::shared_ptr<MscclppBootstrap> bootstrap(new MscclppBootstrap(rank, worldSize));
// bootstrap->Initialize("costsim-dev-00000A:50000");
UniqueId id;
if (rank == 0)
id = bootstrap->GetUniqueId();
MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
bootstrap->Initialize(id);
// need to call initialization first
MPI_Finalize();