|
|
|
|
@@ -1,13 +1,14 @@
|
|
|
|
|
#include "registered_memory.hpp"
|
|
|
|
|
#include "api.h"
|
|
|
|
|
#include "checks.hpp"
|
|
|
|
|
#include "utils.h"
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <cuda.h>
|
|
|
|
|
|
|
|
|
|
namespace mscclpp {
|
|
|
|
|
|
|
|
|
|
RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl)
|
|
|
|
|
: data(data), dataInitialized(true), size(size), rank(rank), transports(transports)
|
|
|
|
|
: data(data), size(size), rank(rank), hostHash(commImpl.rankToHash_.at(rank)), transports(transports)
|
|
|
|
|
{
|
|
|
|
|
if (transports.has(Transport::CudaIpc)) {
|
|
|
|
|
TransportInfo transportInfo;
|
|
|
|
|
@@ -18,7 +19,9 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t
|
|
|
|
|
size_t baseDataSize; // dummy
|
|
|
|
|
CUTHROW(cuMemGetAddressRange((CUdeviceptr*)&baseDataPtr, &baseDataSize, (CUdeviceptr)data));
|
|
|
|
|
CUDATHROW(cudaIpcGetMemHandle(&handle, baseDataPtr));
|
|
|
|
|
transportInfo.cudaIpcHandle = handle;
|
|
|
|
|
// TODO: bug with offset of base?
|
|
|
|
|
transportInfo.cudaIpcBaseHandle = handle;
|
|
|
|
|
transportInfo.cudaIpcOffsetFromBase = (char*)data - (char*)baseDataPtr;
|
|
|
|
|
this->transportInfos.push_back(transportInfo);
|
|
|
|
|
}
|
|
|
|
|
if ((transports & AllIBTransports).any()) {
|
|
|
|
|
@@ -57,24 +60,12 @@ MSCCLPP_API_CPP RegisteredMemory::RegisteredMemory(std::shared_ptr<Impl> pimpl)
|
|
|
|
|
|
|
|
|
|
MSCCLPP_API_CPP RegisteredMemory::~RegisteredMemory() = default;
|
|
|
|
|
|
|
|
|
|
void* RegisteredMemory::data()
|
|
|
|
|
MSCCLPP_API_CPP void* RegisteredMemory::data()
|
|
|
|
|
{
|
|
|
|
|
if (!pimpl->dataInitialized) {
|
|
|
|
|
if (pimpl->transports.has(Transport::CudaIpc)) {
|
|
|
|
|
auto entry = pimpl->getTransportInfo(Transport::CudaIpc);
|
|
|
|
|
CUDATHROW(cudaIpcOpenMemHandle(&pimpl->data, entry.cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess));
|
|
|
|
|
INFO(MSCCLPP_P2P, "Opened CUDA IPC handle for base point of %p", pimpl->data);
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
pimpl->data = nullptr;
|
|
|
|
|
}
|
|
|
|
|
pimpl->dataInitialized = true;
|
|
|
|
|
}
|
|
|
|
|
return pimpl->data;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t RegisteredMemory::size()
|
|
|
|
|
MSCCLPP_API_CPP size_t RegisteredMemory::size()
|
|
|
|
|
{
|
|
|
|
|
return pimpl->size;
|
|
|
|
|
}
|
|
|
|
|
@@ -84,7 +75,7 @@ MSCCLPP_API_CPP int RegisteredMemory::rank()
|
|
|
|
|
return pimpl->rank;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TransportFlags RegisteredMemory::transports()
|
|
|
|
|
MSCCLPP_API_CPP TransportFlags RegisteredMemory::transports()
|
|
|
|
|
{
|
|
|
|
|
return pimpl->transports;
|
|
|
|
|
}
|
|
|
|
|
@@ -94,6 +85,7 @@ MSCCLPP_API_CPP std::vector<char> RegisteredMemory::serialize()
|
|
|
|
|
std::vector<char> result;
|
|
|
|
|
std::copy_n(reinterpret_cast<char*>(&pimpl->size), sizeof(pimpl->size), std::back_inserter(result));
|
|
|
|
|
std::copy_n(reinterpret_cast<char*>(&pimpl->rank), sizeof(pimpl->rank), std::back_inserter(result));
|
|
|
|
|
std::copy_n(reinterpret_cast<char*>(&pimpl->hostHash), sizeof(pimpl->hostHash), std::back_inserter(result));
|
|
|
|
|
std::copy_n(reinterpret_cast<char*>(&pimpl->transports), sizeof(pimpl->transports), std::back_inserter(result));
|
|
|
|
|
if (pimpl->transportInfos.size() > std::numeric_limits<int8_t>::max()) {
|
|
|
|
|
throw std::runtime_error("Too many transport info entries");
|
|
|
|
|
@@ -103,7 +95,9 @@ MSCCLPP_API_CPP std::vector<char> RegisteredMemory::serialize()
|
|
|
|
|
for (auto& entry : pimpl->transportInfos) {
|
|
|
|
|
std::copy_n(reinterpret_cast<char*>(&entry.transport), sizeof(entry.transport), std::back_inserter(result));
|
|
|
|
|
if (entry.transport == Transport::CudaIpc) {
|
|
|
|
|
std::copy_n(reinterpret_cast<char*>(&entry.cudaIpcHandle), sizeof(entry.cudaIpcHandle),
|
|
|
|
|
std::copy_n(reinterpret_cast<char*>(&entry.cudaIpcBaseHandle), sizeof(entry.cudaIpcBaseHandle),
|
|
|
|
|
std::back_inserter(result));
|
|
|
|
|
std::copy_n(reinterpret_cast<char*>(&entry.cudaIpcOffsetFromBase), sizeof(entry.cudaIpcOffsetFromBase),
|
|
|
|
|
std::back_inserter(result));
|
|
|
|
|
} else if (AllIBTransports.has(entry.transport)) {
|
|
|
|
|
std::copy_n(reinterpret_cast<char*>(&entry.ibMrInfo), sizeof(entry.ibMrInfo), std::back_inserter(result));
|
|
|
|
|
@@ -126,6 +120,8 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization)
|
|
|
|
|
it += sizeof(this->size);
|
|
|
|
|
std::copy_n(it, sizeof(this->rank), reinterpret_cast<char*>(&this->rank));
|
|
|
|
|
it += sizeof(this->rank);
|
|
|
|
|
std::copy_n(it, sizeof(this->hostHash), reinterpret_cast<char*>(&this->hostHash));
|
|
|
|
|
it += sizeof(this->hostHash);
|
|
|
|
|
std::copy_n(it, sizeof(this->transports), reinterpret_cast<char*>(&this->transports));
|
|
|
|
|
it += sizeof(this->transports);
|
|
|
|
|
int8_t transportCount;
|
|
|
|
|
@@ -136,15 +132,13 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization)
|
|
|
|
|
std::copy_n(it, sizeof(transportInfo.transport), reinterpret_cast<char*>(&transportInfo.transport));
|
|
|
|
|
it += sizeof(transportInfo.transport);
|
|
|
|
|
if (transportInfo.transport == Transport::CudaIpc) {
|
|
|
|
|
cudaIpcMemHandle_t handle;
|
|
|
|
|
std::copy_n(it, sizeof(handle), reinterpret_cast<char*>(&handle));
|
|
|
|
|
it += sizeof(handle);
|
|
|
|
|
transportInfo.cudaIpcHandle = handle;
|
|
|
|
|
std::copy_n(it, sizeof(transportInfo.cudaIpcBaseHandle), reinterpret_cast<char*>(&transportInfo.cudaIpcBaseHandle));
|
|
|
|
|
it += sizeof(transportInfo.cudaIpcBaseHandle);
|
|
|
|
|
std::copy_n(it, sizeof(transportInfo.cudaIpcOffsetFromBase), reinterpret_cast<char*>(&transportInfo.cudaIpcOffsetFromBase));
|
|
|
|
|
it += sizeof(transportInfo.cudaIpcOffsetFromBase);
|
|
|
|
|
} else if (AllIBTransports.has(transportInfo.transport)) {
|
|
|
|
|
IbMrInfo info;
|
|
|
|
|
std::copy_n(it, sizeof(info), reinterpret_cast<char*>(&info));
|
|
|
|
|
it += sizeof(info);
|
|
|
|
|
transportInfo.ibMrInfo = info;
|
|
|
|
|
std::copy_n(it, sizeof(transportInfo.ibMrInfo), reinterpret_cast<char*>(&transportInfo.ibMrInfo));
|
|
|
|
|
it += sizeof(transportInfo.ibMrInfo);
|
|
|
|
|
transportInfo.ibLocal = false;
|
|
|
|
|
} else {
|
|
|
|
|
throw std::runtime_error("Unknown transport");
|
|
|
|
|
@@ -155,7 +149,16 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization)
|
|
|
|
|
throw std::runtime_error("Deserialization failed");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
dataInitialized = false;
|
|
|
|
|
if (transports.has(Transport::CudaIpc)) {
|
|
|
|
|
uint64_t localHostHash = getHostHash();
|
|
|
|
|
if (localHostHash == this->hostHash) {
|
|
|
|
|
auto entry = getTransportInfo(Transport::CudaIpc);
|
|
|
|
|
void* base;
|
|
|
|
|
CUDATHROW(cudaIpcOpenMemHandle(&base, entry.cudaIpcBaseHandle, cudaIpcMemLazyEnablePeerAccess));
|
|
|
|
|
data = static_cast<char*>(base) + entry.cudaIpcOffsetFromBase;
|
|
|
|
|
INFO(MSCCLPP_P2P, "Opened CUDA IPC handle at pointer %p", data);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace mscclpp
|
|
|
|
|
|