diff --git a/Makefile b/Makefile index 881296f4..41004ce3 100644 --- a/Makefile +++ b/Makefile @@ -134,7 +134,7 @@ HEADERS := $(wildcard src/include/*.h) CPPSOURCES := $(shell find ./ -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)' -not -path "./build/*" -not -path "./python/*") PYTHONCPPSOURCES := $(shell find ./python/src/ -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)') -INCEXPORTS := mscclpp.h mscclppfifo.h +INCEXPORTS := mscclpp.h mscclppfifo.h mscclpp.hpp mscclppfifo.hpp INCTARGETS := $(INCEXPORTS:%=$(BUILDDIR)/$(INCDIR)/%) LIBNAME := libmscclpp.so @@ -198,7 +198,7 @@ $(BUILDDIR)/$(OBJDIR)/$(UTDIR)/%.o: $(UTDIR)/%.cc $(HEADERS) @mkdir -p $(@D) $(CXX) -o $@ $(INCLUDE) $(CXXFLAGS) -c $< -$(BUILDDIR)/$(INCDIR)/%.h: src/$(INCDIR)/%.h +$(BUILDDIR)/$(INCDIR)/%: src/$(INCDIR)/% @mkdir -p $(@D) cp $< $@ @@ -216,7 +216,7 @@ $(BUILDDIR)/$(BINDIR)/$(UTDIR)/%: $(BUILDDIR)/$(OBJDIR)/$(UTDIR)/%.o $(LIBOBJTAR # Compile .cc tests $(BUILDDIR)/$(OBJDIR)/$(TESTSDIR)/%.o: $(TESTSDIR)/%.cc $(INCTARGETS) @mkdir -p $(@D) - $(CXX) -o $@ -I$(BUILDDIR)/$(INCDIR) $(MPI_INC) $(CXXFLAGS) -Isrc/include -c $< $(MPI_MACRO) + $(CXX) -o $@ -I$(BUILDDIR)/$(INCDIR) $(MPI_INC) $(CXXFLAGS) -c $< $(MPI_MACRO) # Compile .cu tests $(BUILDDIR)/$(OBJDIR)/$(TESTSDIR)/%.o: $(TESTSDIR)/%.cu $(INCTARGETS) diff --git a/src/bootstrap/bootstrap.cc b/src/bootstrap/bootstrap.cc index fc9f0645..2447489e 100644 --- a/src/bootstrap/bootstrap.cc +++ b/src/bootstrap/bootstrap.cc @@ -1,6 +1,8 @@ +#include "mscclpp.hpp" #include "bootstrap.h" #include "utils.h" #include "checks.hpp" +#include "api.h" #include #include @@ -11,6 +13,8 @@ #include #include +using namespace mscclpp; + namespace { uint64_t hashUniqueId(const mscclppBootstrapHandle& id) { @@ -57,24 +61,33 @@ struct extInfo mscclppSocketAddress extAddressListen; }; -class mscclppBootstrap::Impl +struct UniqueIdInternal +{ + uint64_t magic; + union mscclppSocketAddress addr; +}; +static_assert(sizeof(UniqueIdInternal) <= sizeof(UniqueId), + "UniqueIdInternal is too large to fit into UniqueId"); + +class DefaultBootstrap::Impl { public: Impl(int rank, int nRanks); ~Impl(); - void Initialize(const UniqueId uniqueId); - void Initialize(std::string ipPortPair); - void EstablishConnections(); - UniqueId GetUniqueId(); - void AllGather(void* allData, int size); - void Send(void* data, int size, int peer, int tag); - void Recv(void* data, int size, int peer, int tag); - void Barrier(); - void Close(); + void initialize(const UniqueId uniqueId); + void initialize(std::string ipPortPair); + void establishConnections(); + UniqueId createUniqueId(); + UniqueId getUniqueId() const; + void allGather(void* allData, int size); + void send(void* data, int size, int peer, int tag); + void recv(void* data, int size, int peer, int tag); + void barrier(); + void close(); - UniqueId uniqueId_; private: + UniqueIdInternal uniqueId_; int rank_; int nRanks_; bool netInitialized; @@ -103,34 +116,38 @@ private: // UniqueId MscclppBootstrap::Impl::uniqueId_; -mscclppBootstrap::Impl::Impl(int rank, int nRanks) +DefaultBootstrap::Impl::Impl(int rank, int nRanks) : rank_(rank), nRanks_(nRanks), netInitialized(false), peerCommAddresses_(nRanks, mscclppSocketAddress()), barrierArr_(nRanks, 0), abortFlag_(nullptr) { } -UniqueId mscclppBootstrap::Impl::GetUniqueId() +UniqueId DefaultBootstrap::Impl::getUniqueId() const +{ + UniqueId ret; + std::memcpy(&ret, &uniqueId_, sizeof(uniqueId_)); + return ret; +} + +UniqueId DefaultBootstrap::Impl::createUniqueId() { netInit(""); MSCCLPPTHROW(getRandomData(&uniqueId_.magic, sizeof(uniqueId_.magic))); std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(mscclppSocketAddress)); bootstrapCreateRoot(); - - return uniqueId_; + return getUniqueId(); } -void mscclppBootstrap::Impl::Initialize(const UniqueId uniqueId) +void DefaultBootstrap::Impl::initialize(const UniqueId uniqueId) { netInit(""); - uniqueId_.magic = uniqueId.magic; - uniqueId_.addr = uniqueId.addr; - // printf("addr = %s port = %d\n", inet_ntoa(uniqueId_.addr.sin.sin_addr), (int)ntohs(uniqueId_.addr.sin.sin_port)); + std::memcpy(&uniqueId_, &uniqueId, sizeof(uniqueId_)); - EstablishConnections(); + establishConnections(); } -void mscclppBootstrap::Impl::Initialize(std::string ipPortPair) +void DefaultBootstrap::Impl::initialize(std::string ipPortPair) { netInit(ipPortPair); @@ -142,17 +159,17 @@ void mscclppBootstrap::Impl::Initialize(std::string ipPortPair) bootstrapCreateRoot(); } - EstablishConnections(); + establishConnections(); } -mscclppBootstrap::Impl::~Impl() +DefaultBootstrap::Impl::~Impl() { if (rootThread_.joinable()) { rootThread_.join(); } } -void mscclppBootstrap::Impl::getRemoteAddresses(mscclppSocket* listenSock, +void DefaultBootstrap::Impl::getRemoteAddresses(mscclppSocket* listenSock, std::vector& rankAddresses, std::vector& rankAddressesRoot, int& rank) @@ -181,7 +198,7 @@ void mscclppBootstrap::Impl::getRemoteAddresses(mscclppSocket* listenSock, rank = info.rank; } -void mscclppBootstrap::Impl::sendHandleToPeer(int peer, +void DefaultBootstrap::Impl::sendHandleToPeer(int peer, const std::vector& rankAddresses, const std::vector& rankAddressesRoot) { @@ -193,7 +210,7 @@ void mscclppBootstrap::Impl::sendHandleToPeer(int peer, MSCCLPPTHROW(mscclppSocketClose(&sock)); } -void mscclppBootstrap::Impl::bootstrapCreateRoot() +void DefaultBootstrap::Impl::bootstrapCreateRoot() { mscclppSocket listenSock; @@ -216,7 +233,7 @@ void mscclppBootstrap::Impl::bootstrapCreateRoot() rootThread_ = std::thread(lambda); } -void mscclppBootstrap::Impl::bootstrapRoot(mscclppSocket listenSock) +void DefaultBootstrap::Impl::bootstrapRoot(mscclppSocket listenSock) { int numCollected = 0; std::vector rankAddresses(this->nRanks_, mscclppSocketAddress()); @@ -245,7 +262,7 @@ void mscclppBootstrap::Impl::bootstrapRoot(mscclppSocket listenSock) TRACE(MSCCLPP_INIT, "DONE"); } -void mscclppBootstrap::Impl::netInit(std::string ipPortPair) +void DefaultBootstrap::Impl::netInit(std::string ipPortPair) { if (netInitialized) return; @@ -271,7 +288,7 @@ void mscclppBootstrap::Impl::netInit(std::string ipPortPair) netInitialized = true; } -void mscclppBootstrap::Impl::EstablishConnections() +void DefaultBootstrap::Impl::establishConnections() { mscclppSocketAddress nextAddr; mscclppSocket sock, listenSockRoot; @@ -332,12 +349,12 @@ void mscclppBootstrap::Impl::EstablishConnections() // AllGather all listen handlers MSCCLPPTHROW(mscclppSocketGetAddr(&this->listenSock_, &this->peerCommAddresses_[rank_])); - AllGather(this->peerCommAddresses_.data(), sizeof(mscclppSocketAddress)); + allGather(this->peerCommAddresses_.data(), sizeof(mscclppSocketAddress)); TRACE(MSCCLPP_INIT, "rank %d nranks %d - DONE", rank_, nRanks_); } -void mscclppBootstrap::Impl::AllGather(void* allData, int size) +void DefaultBootstrap::Impl::allGather(void* allData, int size) { char* data = static_cast(allData); int rank = this->rank_; @@ -362,13 +379,13 @@ void mscclppBootstrap::Impl::AllGather(void* allData, int size) TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d - DONE", rank, nRanks, size); } -void mscclppBootstrap::Impl::netSend(mscclppSocket* sock, const void* data, int size) +void DefaultBootstrap::Impl::netSend(mscclppSocket* sock, const void* data, int size) { MSCCLPPTHROW(mscclppSocketSend(sock, &size, sizeof(int))); MSCCLPPTHROW(mscclppSocketSend(sock, const_cast(data), size)); } -void mscclppBootstrap::Impl::netRecv(mscclppSocket* sock, void* data, int size) +void DefaultBootstrap::Impl::netRecv(mscclppSocket* sock, void* data, int size) { int recvSize; MSCCLPPTHROW(mscclppSocketRecv(sock, &recvSize, sizeof(int))); @@ -378,7 +395,7 @@ void mscclppBootstrap::Impl::netRecv(mscclppSocket* sock, void* data, int size) MSCCLPPTHROW(mscclppSocketRecv(sock, data, std::min(recvSize, size))); } -void mscclppBootstrap::Impl::Send(void* data, int size, int peer, int tag) +void DefaultBootstrap::Impl::send(void* data, int size, int peer, int tag) { mscclppSocket sock; MSCCLPPTHROW(mscclppSocketInit(&sock, &this->peerCommAddresses_[peer], this->uniqueId_.magic, @@ -391,7 +408,7 @@ void mscclppBootstrap::Impl::Send(void* data, int size, int peer, int tag) MSCCLPPTHROW(mscclppSocketClose(&sock)); } -void mscclppBootstrap::Impl::Recv(void* data, int size, int peer, int tag) +void DefaultBootstrap::Impl::recv(void* data, int size, int peer, int tag) { // search over all unexpected messages for (auto it = unexpectedMessages_.begin(); it != unexpectedMessages_.end(); ++it){ @@ -421,62 +438,67 @@ void mscclppBootstrap::Impl::Recv(void* data, int size, int peer, int tag) } } -void mscclppBootstrap::Impl::Barrier() +void DefaultBootstrap::Impl::barrier() { - AllGather(barrierArr_.data(), sizeof(int)); + allGather(barrierArr_.data(), sizeof(int)); } -void mscclppBootstrap::Impl::Close() +void DefaultBootstrap::Impl::close() { MSCCLPPTHROW(mscclppSocketClose(&this->listenSock_)); MSCCLPPTHROW(mscclppSocketClose(&this->ringSendSocket_)); MSCCLPPTHROW(mscclppSocketClose(&this->ringRecvSocket_)); } -mscclppBootstrap::mscclppBootstrap(int rank, int nRanks) +MSCCLPP_API_CPP DefaultBootstrap::DefaultBootstrap(int rank, int nRanks) { // pimpl_ = std::make_unique(ipPortPair, rank, nRanks, uniqueId); pimpl_ = std::make_unique(rank, nRanks); } -UniqueId mscclppBootstrap::GetUniqueId() +MSCCLPP_API_CPP UniqueId DefaultBootstrap::createUniqueId() { - return pimpl_->GetUniqueId(); + return pimpl_->createUniqueId(); } -void mscclppBootstrap::Send(void* data, int size, int peer, int tag) +MSCCLPP_API_CPP UniqueId DefaultBootstrap::getUniqueId() const { - pimpl_->Send(data, size, peer, tag); + return pimpl_->getUniqueId(); } -void mscclppBootstrap::Recv(void* data, int size, int peer, int tag) +MSCCLPP_API_CPP void DefaultBootstrap::send(void* data, int size, int peer, int tag) { - pimpl_->Recv(data, size, peer, tag); + pimpl_->send(data, size, peer, tag); } -void mscclppBootstrap::AllGather(void* allData, int size) +MSCCLPP_API_CPP void DefaultBootstrap::recv(void* data, int size, int peer, int tag) { - pimpl_->AllGather(allData, size); + pimpl_->recv(data, size, peer, tag); } -void mscclppBootstrap::Initialize(UniqueId uniqueId) +MSCCLPP_API_CPP void DefaultBootstrap::allGather(void* allData, int size) { - pimpl_->Initialize(uniqueId); + pimpl_->allGather(allData, size); } -void mscclppBootstrap::Initialize(std::string ipPortPair) +MSCCLPP_API_CPP void DefaultBootstrap::initialize(UniqueId uniqueId) { - pimpl_->Initialize(ipPortPair); + pimpl_->initialize(uniqueId); } -void mscclppBootstrap::Barrier() +MSCCLPP_API_CPP void DefaultBootstrap::initialize(std::string ipPortPair) { - pimpl_->Barrier(); + pimpl_->initialize(ipPortPair); } -mscclppBootstrap::~mscclppBootstrap() +MSCCLPP_API_CPP void DefaultBootstrap::barrier() { - pimpl_->Close(); + pimpl_->barrier(); +} + +MSCCLPP_API_CPP DefaultBootstrap::~DefaultBootstrap() +{ + pimpl_->close(); } // ------------------- Old bootstrap functions ------------------- diff --git a/src/include/bootstrap.h b/src/include/bootstrap.h index 2a6b99ba..6bb20f81 100644 --- a/src/include/bootstrap.h +++ b/src/include/bootstrap.h @@ -5,35 +5,6 @@ #include "comm.h" -struct UniqueId -{ - uint64_t magic; - union mscclppSocketAddress addr; -}; - -static_assert(sizeof(UniqueId) <= sizeof(mscclppUniqueId), - "Bootstrap handle is too large to fit inside MSCCLPP unique ID"); - -class __attribute__((visibility("default"))) mscclppBootstrap : public Bootstrap -{ -public: - mscclppBootstrap(int rank, int nRanks); - ~mscclppBootstrap(); - - UniqueId GetUniqueId(); - - void Initialize(UniqueId uniqueId); - void Initialize(std::string ipPortPair); - void Send(void* data, int size, int peer, int tag) override; - void Recv(void* data, int size, int peer, int tag) override; - void AllGather(void* allData, int size) override; - void Barrier() override; - -private: - class Impl; - std::unique_ptr pimpl_; -}; - // ------------------- Old bootstrap headers: to be removed ------------------- struct mscclppBootstrapHandle diff --git a/src/include/mscclpp.h b/src/include/mscclpp.h index a9675d1e..6f96af10 100644 --- a/src/include/mscclpp.h +++ b/src/include/mscclpp.h @@ -248,16 +248,6 @@ typedef enum } mscclppResult_t; -class Bootstrap { -public: - Bootstrap(){}; - virtual ~Bootstrap() = default; - virtual void Send(void* data, int size, int peer, int tag) = 0; - virtual void Recv(void* data, int size, int peer, int tag) = 0; - virtual void AllGather(void* allData, int size) = 0; - virtual void Barrier() = 0; -}; - /* Create a unique ID for communication. Only needs to be called by one process. * Use with mscclppCommInitRankFromId(). * All processes need to provide the same ID to mscclppCommInitRankFromId(). diff --git a/src/include/mscclpp.hpp b/src/include/mscclpp.hpp index e41e94b8..6a7230bd 100644 --- a/src/include/mscclpp.hpp +++ b/src/include/mscclpp.hpp @@ -13,12 +13,51 @@ #include #include +#include #include #include namespace mscclpp { +#define MSCCLPP_UNIQUE_ID_BYTES 128 +struct UniqueId { + char internal[MSCCLPP_UNIQUE_ID_BYTES]; +}; + +class Bootstrap +{ +public: + Bootstrap(){}; + virtual ~Bootstrap() = default; + virtual void send(void* data, int size, int peer, int tag) = 0; + virtual void recv(void* data, int size, int peer, int tag) = 0; + virtual void allGather(void* allData, int size) = 0; + virtual void barrier() = 0; +}; + +class DefaultBootstrap : public Bootstrap +{ +public: + DefaultBootstrap(int rank, int nRanks); + ~DefaultBootstrap(); + + UniqueId createUniqueId(); + UniqueId getUniqueId() const; + + void initialize(UniqueId uniqueId); + void initialize(std::string ipPortPair); + void send(void* data, int size, int peer, int tag) override; + void recv(void* data, int size, int peer, int tag) override; + void allGather(void* allData, int size) override; + void barrier() override; + +private: + class Impl; + std::unique_ptr pimpl_; +}; + + struct alignas(16) SignalEpochId { // every signal(), increaments this and either: // 1) proxy thread pushes it to the remote peer's localSignalEpochId->proxy @@ -381,11 +420,6 @@ struct SimpleDeviceConnection { BufferHandle src; }; -#define MSCCLPP_UNIQUE_ID_BYTES 128 -struct UniqueId { - char internal[MSCCLPP_UNIQUE_ID_BYTES]; -}; - /* Create a unique ID for communication. Only needs to be called by one process. * Use with mscclppCommInitRankFromId(). * All processes need to provide the same ID to mscclppCommInitRankFromId(). diff --git a/tests/bootstrap_test_cpp.cc b/tests/bootstrap_test_cpp.cc index c2ef61f0..534c4114 100644 --- a/tests/bootstrap_test_cpp.cc +++ b/tests/bootstrap_test_cpp.cc @@ -1,4 +1,4 @@ -#include "bootstrap.h" +#include "mscclpp.hpp" #include @@ -11,24 +11,24 @@ int main() MPI_Comm_rank(MPI_COMM_WORLD, &rank); MPI_Comm_size(MPI_COMM_WORLD, &worldSize); - std::shared_ptr bootstrap(new mscclppBootstrap(rank, worldSize)); + std::shared_ptr bootstrap(new mscclpp::DefaultBootstrap(rank, worldSize)); // bootstrap->Initialize("costsim-dev-00000A:50000"); - UniqueId id; + mscclpp::UniqueId id; if (rank == 0) - id = bootstrap->GetUniqueId(); + id = bootstrap->createUniqueId(); MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD); - bootstrap->Initialize(id); + bootstrap->initialize(id); std::vector tmp(worldSize, 0); tmp[rank] = rank+1; - bootstrap->AllGather(tmp.data(), sizeof(int)); + bootstrap->allGather(tmp.data(), sizeof(int)); for (int i = 0; i < worldSize; i++){ if (tmp[i] != i+1) printf("error AllGather: rank %d: tmp[%d] = %d\n", rank, i, tmp[i]); } printf("rank %d: AllGather test passed!\n", rank); - bootstrap->Barrier(); + bootstrap->barrier(); printf("rank %d: Barrier test passed!\n", rank); for (int i = 0; i < worldSize; i++){ @@ -36,8 +36,8 @@ int main() continue; int msg1 = (rank + 1)*2; int msg2 = (rank + 1)*2+1; - bootstrap->Send(&msg1, sizeof(int), i, 0); - bootstrap->Send(&msg2, sizeof(int), i, 1); + bootstrap->send(&msg1, sizeof(int), i, 0); + bootstrap->send(&msg2, sizeof(int), i, 1); } for (int i = 0; i < worldSize; i++){ @@ -46,8 +46,8 @@ int main() int msg1 = 0; int msg2 = 0; // recv them in the opposite order to check correctness - bootstrap->Recv(&msg2, sizeof(int), i, 1); - bootstrap->Recv(&msg1, sizeof(int), i, 0); + bootstrap->recv(&msg2, sizeof(int), i, 1); + bootstrap->recv(&msg1, sizeof(int), i, 0); if (msg1 != (i+1)*2 || msg2 != (i+1)*2+1) printf("error Send/Recv: rank %d: msg1 = %d, msg2 = %d\n", rank, msg1, msg2); }