mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 01:10:22 +00:00
epoch creation
This commit is contained in:
@@ -98,7 +98,7 @@ NonblockingFuture<RegisteredMemory> Communicator::recvMemoryOnSetup(int remoteRa
|
||||
{
|
||||
auto memoryReceiver = std::make_shared<MemoryReceiver>(remoteRank, tag);
|
||||
addSetup(memoryReceiver);
|
||||
return memoryReceiver->memoryPromise_.get_future();
|
||||
return NonblockingFuture<RegisteredMemory>(memoryReceiver->memoryPromise_.get_future());
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connect(int remoteRank, int tag, Transport transport)
|
||||
|
||||
11
src/epoch.cc
11
src/epoch.cc
@@ -1,10 +1,11 @@
|
||||
#include "epoch.hpp"
|
||||
#include "checks.hpp"
|
||||
#include "alloc.h"
|
||||
#include "api.h"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
Epoch::Epoch(Communicator& communicator, std::shared_ptr<Connection> connection) : connection_(connection) {
|
||||
MSCCLPP_API_CPP Epoch::Epoch(Communicator& communicator, std::shared_ptr<Connection> connection) : connection_(connection) {
|
||||
MSCCLPPTHROW(mscclppCudaCalloc(&device_.epochIds_, 1));
|
||||
MSCCLPPTHROW(mscclppCudaCalloc(&device_.expectedInboundEpochId_, 1));
|
||||
|
||||
@@ -13,12 +14,12 @@ Epoch::Epoch(Communicator& communicator, std::shared_ptr<Connection> connection)
|
||||
remoteEpochIdsRegMem_ = communicator.recvMemoryOnSetup(connection->remoteRank(), connection->tag());
|
||||
}
|
||||
|
||||
Epoch::~Epoch() {
|
||||
MSCCLPPTHROW(mscclppCudaFree(&device_.epochIds_));
|
||||
MSCCLPPTHROW(mscclppCudaFree(&device_.expectedInboundEpochId_));
|
||||
MSCCLPP_API_CPP Epoch::~Epoch() {
|
||||
mscclppCudaFree(device_.epochIds_);
|
||||
mscclppCudaFree(device_.expectedInboundEpochId_);
|
||||
}
|
||||
|
||||
void Epoch::signal() {
|
||||
MSCCLPP_API_CPP void Epoch::signal() {
|
||||
connection_->write(remoteEpochIdsRegMem_.get(), offsetof(EpochIds, inboundReplica_), localEpochIdsRegMem_, offsetof(EpochIds, outbound_), sizeof(device_.epochIds_));
|
||||
}
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
if (res != mscclppSuccess && res != mscclppInProgress) { \
|
||||
throw std::runtime_error(std::string("Call to " #call " failed with error code ") + mscclppGetErrorString(res)); \
|
||||
} \
|
||||
} while (0);
|
||||
} while (false)
|
||||
|
||||
#define CUDATHROW(cmd) \
|
||||
do { \
|
||||
|
||||
@@ -35,15 +35,16 @@ public:
|
||||
// TODO: move implementations of these helpers out of this header
|
||||
void send(const std::vector<char>& data, int peer, int tag)
|
||||
{
|
||||
send((void*)data.size(), sizeof(size_t), peer, tag);
|
||||
send((void*)data.data(), data.size(), peer, tag);
|
||||
size_t size = data.size();
|
||||
send((void*)&size, sizeof(size_t), peer, tag);
|
||||
send((void*)data.data(), data.size(), peer, tag+1);
|
||||
}
|
||||
void recv(std::vector<char>& data, int peer, int tag)
|
||||
{
|
||||
size_t size;
|
||||
recv((void*)&size, sizeof(size_t), peer, tag);
|
||||
data.resize(size);
|
||||
recv((void*)data.data(), data.size(), peer, tag);
|
||||
recv((void*)data.data(), data.size(), peer, tag+1);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "mscclpp.hpp"
|
||||
#include "epoch.hpp"
|
||||
|
||||
#include <cassert>
|
||||
#include <cuda_runtime.h>
|
||||
@@ -88,7 +89,7 @@ void test_communicator(int rank, int worldSize, int nranksPerNode)
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "Connection setup passed" << std::endl;
|
||||
|
||||
int numBuffers = 1000;
|
||||
int numBuffers = 1;
|
||||
std::vector<int*> devicePtr(numBuffers);
|
||||
int deviceBufferSize = 1024*1024;
|
||||
|
||||
@@ -105,6 +106,15 @@ void test_communicator(int rank, int worldSize, int nranksPerNode)
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "Memory registration for " << std::to_string(numBuffers) << " buffers passed" << std::endl;
|
||||
|
||||
std::vector<std::unique_ptr<mscclpp::Epoch>> epochs;
|
||||
for (auto entry : connections) {
|
||||
auto& conn = entry.second;
|
||||
epochs.emplace_back(std::make_unique<mscclpp::Epoch>(*communicator, conn));
|
||||
}
|
||||
communicator->setup();
|
||||
bootstrap->barrier();
|
||||
if (bootstrap->getRank() == 0)
|
||||
std::cout << "Epochs are created" << std::endl;
|
||||
|
||||
assert((deviceBufferSize / sizeof(int)) % worldSize == 0);
|
||||
size_t writeSize = deviceBufferSize / worldSize;
|
||||
|
||||
Reference in New Issue
Block a user