diff --git a/src/communicator.cc b/src/communicator.cc index 7af88c73..2507c175 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -98,7 +98,7 @@ NonblockingFuture Communicator::recvMemoryOnSetup(int remoteRa { auto memoryReceiver = std::make_shared(remoteRank, tag); addSetup(memoryReceiver); - return memoryReceiver->memoryPromise_.get_future(); + return NonblockingFuture(memoryReceiver->memoryPromise_.get_future()); } MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank, int tag, Transport transport) diff --git a/src/epoch.cc b/src/epoch.cc index a14191fd..9263fd1c 100644 --- a/src/epoch.cc +++ b/src/epoch.cc @@ -1,10 +1,11 @@ #include "epoch.hpp" #include "checks.hpp" #include "alloc.h" +#include "api.h" namespace mscclpp { -Epoch::Epoch(Communicator& communicator, std::shared_ptr connection) : connection_(connection) { +MSCCLPP_API_CPP Epoch::Epoch(Communicator& communicator, std::shared_ptr connection) : connection_(connection) { MSCCLPPTHROW(mscclppCudaCalloc(&device_.epochIds_, 1)); MSCCLPPTHROW(mscclppCudaCalloc(&device_.expectedInboundEpochId_, 1)); @@ -13,12 +14,12 @@ Epoch::Epoch(Communicator& communicator, std::shared_ptr connection) remoteEpochIdsRegMem_ = communicator.recvMemoryOnSetup(connection->remoteRank(), connection->tag()); } -Epoch::~Epoch() { - MSCCLPPTHROW(mscclppCudaFree(&device_.epochIds_)); - MSCCLPPTHROW(mscclppCudaFree(&device_.expectedInboundEpochId_)); +MSCCLPP_API_CPP Epoch::~Epoch() { + mscclppCudaFree(device_.epochIds_); + mscclppCudaFree(device_.expectedInboundEpochId_); } -void Epoch::signal() { +MSCCLPP_API_CPP void Epoch::signal() { connection_->write(remoteEpochIdsRegMem_.get(), offsetof(EpochIds, inboundReplica_), localEpochIdsRegMem_, offsetof(EpochIds, outbound_), sizeof(device_.epochIds_)); } diff --git a/src/include/checks.hpp b/src/include/checks.hpp index 69b222ee..6473c92f 100644 --- a/src/include/checks.hpp +++ b/src/include/checks.hpp @@ -17,7 +17,7 @@ if (res != mscclppSuccess && res != mscclppInProgress) { \ throw std::runtime_error(std::string("Call to " #call " failed with error code ") + mscclppGetErrorString(res)); \ } \ - } while (0); + } while (false) #define CUDATHROW(cmd) \ do { \ diff --git a/src/include/mscclpp.hpp b/src/include/mscclpp.hpp index 5186fbc2..4c26131c 100644 --- a/src/include/mscclpp.hpp +++ b/src/include/mscclpp.hpp @@ -35,15 +35,16 @@ public: // TODO: move implementations of these helpers out of this header void send(const std::vector& data, int peer, int tag) { - send((void*)data.size(), sizeof(size_t), peer, tag); - send((void*)data.data(), data.size(), peer, tag); + size_t size = data.size(); + send((void*)&size, sizeof(size_t), peer, tag); + send((void*)data.data(), data.size(), peer, tag+1); } void recv(std::vector& data, int peer, int tag) { size_t size; recv((void*)&size, sizeof(size_t), peer, tag); data.resize(size); - recv((void*)data.data(), data.size(), peer, tag); + recv((void*)data.data(), data.size(), peer, tag+1); } }; diff --git a/tests/communicator_test_cpp.cc b/tests/communicator_test_cpp.cc index c922eaae..29712cd0 100644 --- a/tests/communicator_test_cpp.cc +++ b/tests/communicator_test_cpp.cc @@ -1,4 +1,5 @@ #include "mscclpp.hpp" +#include "epoch.hpp" #include #include @@ -88,7 +89,7 @@ void test_communicator(int rank, int worldSize, int nranksPerNode) if (bootstrap->getRank() == 0) std::cout << "Connection setup passed" << std::endl; - int numBuffers = 1000; + int numBuffers = 1; std::vector devicePtr(numBuffers); int deviceBufferSize = 1024*1024; @@ -105,6 +106,15 @@ void test_communicator(int rank, int worldSize, int nranksPerNode) if (bootstrap->getRank() == 0) std::cout << "Memory registration for " << std::to_string(numBuffers) << " buffers passed" << std::endl; + std::vector> epochs; + for (auto entry : connections) { + auto& conn = entry.second; + epochs.emplace_back(std::make_unique(*communicator, conn)); + } + communicator->setup(); + bootstrap->barrier(); + if (bootstrap->getRank() == 0) + std::cout << "Epochs are created" << std::endl; assert((deviceBufferSize / sizeof(int)) % worldSize == 0); size_t writeSize = deviceBufferSize / worldSize;