mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-13 17:55:52 +00:00
wip
This commit is contained in:
@@ -59,20 +59,24 @@ struct extInfo
|
||||
class MscclppBootstrap::Impl
|
||||
{
|
||||
public:
|
||||
Impl(std::string ipPortPair, int rank, int nRanks, const UniqueId uniqueId);
|
||||
Impl(int rank, int nRanks);
|
||||
~Impl();
|
||||
mscclppResult_t initialize();
|
||||
void initialize(const UniqueId uniqueId);
|
||||
void initialize(std::string ipPortPair);
|
||||
mscclppResult_t establishConnections();
|
||||
UniqueId getUniqueId();
|
||||
mscclppResult_t allGather(void* allData, int size);
|
||||
mscclppResult_t send(void* data, int size, int peer, int tag);
|
||||
mscclppResult_t recv(void* data, int size, int peer, int tag);
|
||||
mscclppResult_t barrier();
|
||||
mscclppResult_t close();
|
||||
|
||||
static UniqueId uniqueId_;
|
||||
UniqueId uniqueId_;
|
||||
|
||||
private:
|
||||
int rank_;
|
||||
int nRanks_;
|
||||
bool netInitialized;
|
||||
mscclppSocket listenSock_;
|
||||
mscclppSocket ringRecvSocket_;
|
||||
mscclppSocket ringSendSocket_;
|
||||
@@ -95,37 +99,67 @@ private:
|
||||
mscclppResult_t netInit(std::string ipPortPair);
|
||||
};
|
||||
|
||||
UniqueId MscclppBootstrap::Impl::uniqueId_;
|
||||
// UniqueId MscclppBootstrap::Impl::uniqueId_;
|
||||
|
||||
MscclppBootstrap::Impl::Impl(std::string ipPortPair, int rank, int nRanks, const UniqueId uniqueId)
|
||||
: rank_(rank), nRanks_(nRanks), peerCommAddresses_(nRanks, mscclppSocketAddress()),
|
||||
MscclppBootstrap::Impl::Impl(int rank, int nRanks)
|
||||
: rank_(rank), nRanks_(nRanks), netInitialized(false), peerCommAddresses_(nRanks, mscclppSocketAddress()),
|
||||
peerProxyAddresses_(nRanks, mscclppSocketAddress()), abortFlag_(nullptr)
|
||||
{
|
||||
}
|
||||
|
||||
UniqueId MscclppBootstrap::Impl::getUniqueId()
|
||||
{
|
||||
UniqueId uniqueId;
|
||||
auto ret = netInit("");
|
||||
if (ret != mscclppSuccess) {
|
||||
throw std::runtime_error("Failed to initialize network");
|
||||
}
|
||||
ret = getRandomData(&uniqueId.magic, sizeof(uniqueId.magic));
|
||||
if (ret != mscclppSuccess) {
|
||||
throw std::runtime_error("getting random data failed");
|
||||
}
|
||||
std::memcpy(&uniqueId.addr, &netIfAddr_, sizeof(union mscclppSocketAddress));
|
||||
|
||||
return uniqueId;
|
||||
}
|
||||
|
||||
void MscclppBootstrap::Impl::initialize(const UniqueId uniqueId)
|
||||
{
|
||||
int ret = netInit("");
|
||||
if (ret != mscclppSuccess) {
|
||||
throw std::runtime_error("Failed to initialize network");
|
||||
}
|
||||
|
||||
uniqueId_.magic = uniqueId.magic;
|
||||
uniqueId_.addr = uniqueId.addr;
|
||||
|
||||
if (rank_ == 0) {
|
||||
rootThread_ = std::thread(&MscclppBootstrap::Impl::bootstrapRoot, this);
|
||||
}
|
||||
|
||||
ret = establishConnections();
|
||||
if (ret != mscclppSuccess) {
|
||||
throw std::runtime_error("Failed to establish connections");
|
||||
}
|
||||
}
|
||||
|
||||
void MscclppBootstrap::Impl::initialize(std::string ipPortPair)
|
||||
{
|
||||
int ret = netInit(ipPortPair);
|
||||
if (ret != mscclppSuccess) {
|
||||
throw std::runtime_error("Failed to initialize network");
|
||||
}
|
||||
|
||||
UniqueId zeroId;
|
||||
std::memset(&zeroId, 0, sizeof(UniqueId));
|
||||
if (std::memcmp(&uniqueId, &zeroId, sizeof(UniqueId)) != 0) {
|
||||
uniqueId_.magic = uniqueId.magic;
|
||||
uniqueId_.addr = uniqueId.addr;
|
||||
return;
|
||||
}
|
||||
|
||||
if (!ipPortPair.empty()) {
|
||||
uniqueId_.magic = 0xdeadbeef;
|
||||
} else {
|
||||
mscclppResult_t ret = getRandomData(&uniqueId_.magic, sizeof(uniqueId_.magic));
|
||||
if (ret != mscclppSuccess) {
|
||||
throw std::runtime_error("getting random data failed");
|
||||
}
|
||||
}
|
||||
uniqueId_.magic = 0xdeadbeef;
|
||||
std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(union mscclppSocketAddress));
|
||||
if (rank_ == 0) {
|
||||
rootThread_ = std::thread(&MscclppBootstrap::Impl::bootstrapRoot, this);
|
||||
}
|
||||
|
||||
ret = establishConnections();
|
||||
if (ret != mscclppSuccess) {
|
||||
throw std::runtime_error("Failed to establish connections");
|
||||
}
|
||||
}
|
||||
|
||||
MscclppBootstrap::Impl::~Impl()
|
||||
@@ -145,33 +179,39 @@ mscclppResult_t MscclppBootstrap::Impl::getRemoteAddresses(mscclppSocket* listen
|
||||
mscclppResult_t res = mscclppSuccess;
|
||||
|
||||
mscclppSocketAddress zero;
|
||||
printf("hh 0\n");
|
||||
std::memset(&zero, 0, sizeof(mscclppSocketAddress));
|
||||
res = mscclppSocketInit(&sock);
|
||||
if (res != mscclppSuccess) {
|
||||
WARN("Bootstrap Root : mscclppSocketInit failed");
|
||||
return res;
|
||||
}
|
||||
printf("hh 1\n");
|
||||
res = mscclppSocketAccept(&sock, listenSock);
|
||||
if (res != mscclppSuccess) {
|
||||
WARN("Bootstrap Root : mscclppSocketAccept failed");
|
||||
return res;
|
||||
}
|
||||
printf("hh 2\n");
|
||||
res = netRecv(&sock, &info, sizeof(info));
|
||||
if (res != mscclppSuccess) {
|
||||
WARN("Bootstrap Root : netRecv failed");
|
||||
return res;
|
||||
}
|
||||
printf("hh 3\n");
|
||||
res = mscclppSocketClose(&sock);
|
||||
if (res != mscclppSuccess) {
|
||||
WARN("Bootstrap Root : mscclppSocketClose failed");
|
||||
return res;
|
||||
}
|
||||
|
||||
printf("hh 4\n");
|
||||
if (this->nRanks_ != info.nRanks) {
|
||||
WARN("Bootstrap Root : mismatch in rank count from procs %d : %d", this->nRanks_, info.nRanks);
|
||||
return res;
|
||||
}
|
||||
|
||||
printf("hh 5\n");
|
||||
if (std::memcmp(&zero, &rankAddressesRoot[info.rank], sizeof(mscclppSocketAddress)) != 0) {
|
||||
WARN("Bootstrap Root : rank %d of %d ranks has already checked in", info.rank, this->nRanks_);
|
||||
return res;
|
||||
@@ -216,6 +256,7 @@ mscclppResult_t MscclppBootstrap::Impl::sendHandleToPeer(int peer,
|
||||
|
||||
mscclppResult_t MscclppBootstrap::Impl::bootstrapRoot()
|
||||
{
|
||||
printf("I am here0 magic %x\n", uniqueId_.magic);
|
||||
mscclppResult_t res = mscclppSuccess;
|
||||
int numCollected = 0;
|
||||
std::vector<mscclppSocketAddress> rankAddresses(this->nRanks_, mscclppSocketAddress());
|
||||
@@ -226,16 +267,20 @@ mscclppResult_t MscclppBootstrap::Impl::bootstrapRoot()
|
||||
std::memset(rankAddressesRoot.data(), 0, sizeof(mscclppSocketAddress) * this->nRanks_);
|
||||
setFilesLimit();
|
||||
|
||||
printf("I am here1 %x\n", uniqueId_.magic);
|
||||
mscclppSocket listenSock;
|
||||
MSCCLPPCHECK(
|
||||
mscclppSocketInit(&listenSock, &uniqueId_.addr, uniqueId_.magic, mscclppSocketTypeBootstrap, nullptr, 0));
|
||||
MSCCLPPCHECK(mscclppSocketListen(&listenSock));
|
||||
printf("I am here2\n");
|
||||
|
||||
TRACE(MSCCLPP_INIT, "BEGIN");
|
||||
printf("I am here3\n");
|
||||
/* Receive addresses from all ranks */
|
||||
do {
|
||||
int rank;
|
||||
res = getRemoteAddresses(&listenSock, rankAddresses, rankAddressesRoot, rank);
|
||||
printf("I am here4\n");
|
||||
if (res != mscclppSuccess) {
|
||||
WARN("Bootstrap Root : getRemoteAddresses failed");
|
||||
break;
|
||||
@@ -262,6 +307,8 @@ mscclppResult_t MscclppBootstrap::Impl::bootstrapRoot()
|
||||
|
||||
mscclppResult_t MscclppBootstrap::Impl::netInit(std::string ipPortPair)
|
||||
{
|
||||
if (netInitialized)
|
||||
return mscclppSuccess;
|
||||
if (!ipPortPair.empty()) {
|
||||
union mscclppSocketAddress remoteAddr;
|
||||
if (mscclppSocketGetAddrFromString(&remoteAddr, ipPortPair.c_str()) != mscclppSuccess) {
|
||||
@@ -284,17 +331,18 @@ mscclppResult_t MscclppBootstrap::Impl::netInit(std::string ipPortPair)
|
||||
std::sprintf(line, " %s:", netIfName_);
|
||||
mscclppSocketToString(&netIfAddr_, line + strlen(line));
|
||||
INFO(MSCCLPP_INIT, "Bootstrap : Using%s", line);
|
||||
netInitialized = true;
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
mscclppResult_t MscclppBootstrap::Impl::initialize()
|
||||
mscclppResult_t MscclppBootstrap::Impl::establishConnections()
|
||||
{
|
||||
mscclppSocket* proxySocket;
|
||||
mscclppSocketAddress nextAddr;
|
||||
mscclppSocket sock, listenSockRoot;
|
||||
extInfo info;
|
||||
|
||||
TRACE(MSCCLPP_INIT, "rank %d nranks %d", rank, nranks);
|
||||
TRACE(MSCCLPP_INIT, "rank %d nranks %d", rank_, nRanks_);
|
||||
|
||||
info.rank = this->rank_;
|
||||
info.nRanks = this->nRanks_;
|
||||
@@ -322,11 +370,21 @@ mscclppResult_t MscclppBootstrap::Impl::initialize()
|
||||
randomSleep(this->rank_);
|
||||
}
|
||||
|
||||
|
||||
char line[SOCKET_NAME_MAXLEN + MAX_IF_NAME_SIZE + 2];
|
||||
std::sprintf(line, " %s:", netIfName_);
|
||||
mscclppSocketToString(&this->uniqueId_.addr, line + strlen(line));
|
||||
|
||||
printf("tt 1 %s\n", line);
|
||||
// send info on my listening socket to root
|
||||
MSCCLPPCHECK(mscclppSocketInit(&sock, &this->uniqueId_.addr, magic, mscclppSocketTypeBootstrap, this->abortFlag_));
|
||||
printf("tt 2\n");
|
||||
MSCCLPPCHECK(mscclppSocketConnect(&sock));
|
||||
printf("tt 3\n");
|
||||
MSCCLPPCHECK(netSend(&sock, &info, sizeof(info)));
|
||||
printf("tt 4\n");
|
||||
MSCCLPPCHECK(mscclppSocketClose(&sock));
|
||||
printf("tt 5\n");
|
||||
|
||||
// get info on my "next" rank in the bootstrap ring from root
|
||||
MSCCLPPCHECK(mscclppSocketInit(&sock));
|
||||
@@ -353,7 +411,7 @@ mscclppResult_t MscclppBootstrap::Impl::initialize()
|
||||
MSCCLPPCHECK(mscclppSocketGetAddr(proxySocket, &this->peerProxyAddresses_[rank_]));
|
||||
MSCCLPPCHECK(allGather(this->peerProxyAddresses_.data(), sizeof(union mscclppSocketAddress)));
|
||||
|
||||
TRACE(MSCCLPP_INIT, "rank %d nranks %d - DONE", rank, nranks);
|
||||
TRACE(MSCCLPP_INIT, "rank %d nranks %d - DONE", rank_, nRanks_);
|
||||
|
||||
return mscclppSuccess;
|
||||
}
|
||||
@@ -380,7 +438,7 @@ mscclppResult_t MscclppBootstrap::Impl::allGather(void* allData, int size)
|
||||
MSCCLPPCHECK(netRecv(&this->ringRecvSocket_, data + rSlice * size, size));
|
||||
}
|
||||
|
||||
TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d - DONE", rank, nranks, size);
|
||||
TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d - DONE", rank, nRanks, size);
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
@@ -432,23 +490,15 @@ mscclppResult_t MscclppBootstrap::Impl::close()
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
MscclppBootstrap::MscclppBootstrap(std::string ipPortPair, int rank, int nRanks)
|
||||
MscclppBootstrap::MscclppBootstrap(int rank, int nRanks)
|
||||
{
|
||||
UniqueId uniqueId;
|
||||
std::memset(&uniqueId, 0, sizeof(uniqueId));
|
||||
// pimpl_ = std::make_unique<Impl>(ipPortPair, rank, nRanks, uniqueId);
|
||||
pimpl_ = new Impl(ipPortPair, rank, nRanks, uniqueId);
|
||||
}
|
||||
|
||||
MscclppBootstrap::MscclppBootstrap(UniqueId uniqueId, int rank, int nRanks)
|
||||
{
|
||||
pimpl_ = new Impl("", rank, nRanks, uniqueId);
|
||||
// pimpl_ = std::make_unique<Impl>("", rank, nRanks, uniqueId);
|
||||
pimpl_ = new Impl(rank, nRanks);
|
||||
}
|
||||
|
||||
UniqueId MscclppBootstrap::GetUniqueId()
|
||||
{
|
||||
return Impl::uniqueId_;
|
||||
return pimpl_->getUniqueId();
|
||||
}
|
||||
|
||||
void MscclppBootstrap::Send(void* data, int size, int peer, int tag)
|
||||
@@ -475,12 +525,14 @@ void MscclppBootstrap::AllGather(void* allData, int size)
|
||||
}
|
||||
}
|
||||
|
||||
void MscclppBootstrap::Initialize()
|
||||
void MscclppBootstrap::Initialize(const UniqueId uniqueId)
|
||||
{
|
||||
mscclppResult_t res = pimpl_->initialize();
|
||||
if (res != mscclppSuccess) {
|
||||
throw std::runtime_error("MscclppBootstrap::Initialize failed");
|
||||
}
|
||||
pimpl_->initialize(uniqueId);
|
||||
}
|
||||
|
||||
void MscclppBootstrap::Initialize(std::string ipPortPair)
|
||||
{
|
||||
pimpl_->initialize(ipPortPair);
|
||||
}
|
||||
|
||||
void MscclppBootstrap::Barrier()
|
||||
|
||||
@@ -17,13 +17,13 @@ static_assert(sizeof(UniqueId) <= sizeof(mscclppUniqueId),
|
||||
class __attribute__((visibility("default"))) MscclppBootstrap : public Bootstrap
|
||||
{
|
||||
public:
|
||||
MscclppBootstrap(std::string ipPortPair, int rank, int nRanks);
|
||||
MscclppBootstrap(UniqueId uniqueId, int rank, int nRanks);
|
||||
MscclppBootstrap(int rank, int nRanks);
|
||||
~MscclppBootstrap() override = default;
|
||||
|
||||
static UniqueId GetUniqueId();
|
||||
UniqueId GetUniqueId();
|
||||
|
||||
void Initialize();
|
||||
void Initialize(const UniqueId uniqueId);
|
||||
void Initialize(std::string ipPortPair);
|
||||
void Send(void* data, int size, int peer, int tag) override;
|
||||
void Recv(void* data, int size, int peer, int tag) override;
|
||||
void AllGather(void* allData, int size) override;
|
||||
|
||||
@@ -11,7 +11,13 @@ int main()
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &worldSize);
|
||||
|
||||
std::shared_ptr<Bootstrap> bootstrap(new MscclppBootstrap("", rank, worldSize));
|
||||
std::shared_ptr<MscclppBootstrap> bootstrap(new MscclppBootstrap(rank, worldSize));
|
||||
// bootstrap->Initialize("costsim-dev-00000A:50000");
|
||||
UniqueId id;
|
||||
if (rank == 0)
|
||||
id = bootstrap->GetUniqueId();
|
||||
MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
|
||||
bootstrap->Initialize(id);
|
||||
// need to call initialization first
|
||||
|
||||
MPI_Finalize();
|
||||
|
||||
Reference in New Issue
Block a user