diff --git a/src/communicator.cc b/src/communicator.cc index 469502b7..1fd64132 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -86,9 +86,7 @@ struct MemoryReceiver : public Setuppable { std::vector data; bootstrap->recv(data, remoteRank_, tag_); - auto memory = RegisteredMemory::deserialize(data); - memory.data(); - memoryPromise_.set_value(memory); + memoryPromise_.set_value(RegisteredMemory::deserialize(data)); } std::promise memoryPromise_; diff --git a/src/connection.cc b/src/connection.cc index 0dee770b..dca3e662 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -55,7 +55,6 @@ Transport CudaIpcConnection::remoteTransport() void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) { - ScopedTimer timer("CudaIpcConnection::write"); validateTransport(dst, remoteTransport()); validateTransport(src, transport()); diff --git a/src/include/registered_memory.hpp b/src/include/registered_memory.hpp index e95507f1..bf4802ce 100644 --- a/src/include/registered_memory.hpp +++ b/src/include/registered_memory.hpp @@ -16,7 +16,10 @@ struct TransportInfo // TODO: rewrite this using std::variant or something bool ibLocal; union { - cudaIpcMemHandle_t cudaIpcHandle; + struct { + cudaIpcMemHandle_t cudaIpcBaseHandle; + size_t cudaIpcOffsetFromBase; + }; struct { const IbMr* ibMr; IbMrInfo ibMrInfo; @@ -27,9 +30,9 @@ struct TransportInfo struct RegisteredMemory::Impl { void* data; - bool dataInitialized; size_t size; int rank; + uint64_t hostHash; TransportFlags transports; std::vector transportInfos; diff --git a/src/registered_memory.cc b/src/registered_memory.cc index abf17a8b..fed732a0 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -1,13 +1,14 @@ #include "registered_memory.hpp" #include "api.h" #include "checks.hpp" +#include "utils.h" #include #include namespace mscclpp { RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl) - : data(data), dataInitialized(true), size(size), rank(rank), transports(transports) + : data(data), size(size), rank(rank), hostHash(commImpl.rankToHash_.at(rank)), transports(transports) { if (transports.has(Transport::CudaIpc)) { TransportInfo transportInfo; @@ -18,7 +19,9 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t size_t baseDataSize; // dummy CUTHROW(cuMemGetAddressRange((CUdeviceptr*)&baseDataPtr, &baseDataSize, (CUdeviceptr)data)); CUDATHROW(cudaIpcGetMemHandle(&handle, baseDataPtr)); - transportInfo.cudaIpcHandle = handle; + // TODO: bug with offset of base? + transportInfo.cudaIpcBaseHandle = handle; + transportInfo.cudaIpcOffsetFromBase = (char*)data - (char*)baseDataPtr; this->transportInfos.push_back(transportInfo); } if ((transports & AllIBTransports).any()) { @@ -57,24 +60,12 @@ MSCCLPP_API_CPP RegisteredMemory::RegisteredMemory(std::shared_ptr pimpl) MSCCLPP_API_CPP RegisteredMemory::~RegisteredMemory() = default; -void* RegisteredMemory::data() +MSCCLPP_API_CPP void* RegisteredMemory::data() { - if (!pimpl->dataInitialized) { - if (pimpl->transports.has(Transport::CudaIpc)) { - auto entry = pimpl->getTransportInfo(Transport::CudaIpc); - CUDATHROW(cudaIpcOpenMemHandle(&pimpl->data, entry.cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess)); - INFO(MSCCLPP_P2P, "Opened CUDA IPC handle for base point of %p", pimpl->data); - } - else - { - pimpl->data = nullptr; - } - pimpl->dataInitialized = true; - } return pimpl->data; } -size_t RegisteredMemory::size() +MSCCLPP_API_CPP size_t RegisteredMemory::size() { return pimpl->size; } @@ -84,7 +75,7 @@ MSCCLPP_API_CPP int RegisteredMemory::rank() return pimpl->rank; } -TransportFlags RegisteredMemory::transports() +MSCCLPP_API_CPP TransportFlags RegisteredMemory::transports() { return pimpl->transports; } @@ -94,6 +85,7 @@ MSCCLPP_API_CPP std::vector RegisteredMemory::serialize() std::vector result; std::copy_n(reinterpret_cast(&pimpl->size), sizeof(pimpl->size), std::back_inserter(result)); std::copy_n(reinterpret_cast(&pimpl->rank), sizeof(pimpl->rank), std::back_inserter(result)); + std::copy_n(reinterpret_cast(&pimpl->hostHash), sizeof(pimpl->hostHash), std::back_inserter(result)); std::copy_n(reinterpret_cast(&pimpl->transports), sizeof(pimpl->transports), std::back_inserter(result)); if (pimpl->transportInfos.size() > std::numeric_limits::max()) { throw std::runtime_error("Too many transport info entries"); @@ -103,7 +95,9 @@ MSCCLPP_API_CPP std::vector RegisteredMemory::serialize() for (auto& entry : pimpl->transportInfos) { std::copy_n(reinterpret_cast(&entry.transport), sizeof(entry.transport), std::back_inserter(result)); if (entry.transport == Transport::CudaIpc) { - std::copy_n(reinterpret_cast(&entry.cudaIpcHandle), sizeof(entry.cudaIpcHandle), + std::copy_n(reinterpret_cast(&entry.cudaIpcBaseHandle), sizeof(entry.cudaIpcBaseHandle), + std::back_inserter(result)); + std::copy_n(reinterpret_cast(&entry.cudaIpcOffsetFromBase), sizeof(entry.cudaIpcOffsetFromBase), std::back_inserter(result)); } else if (AllIBTransports.has(entry.transport)) { std::copy_n(reinterpret_cast(&entry.ibMrInfo), sizeof(entry.ibMrInfo), std::back_inserter(result)); @@ -126,6 +120,8 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) it += sizeof(this->size); std::copy_n(it, sizeof(this->rank), reinterpret_cast(&this->rank)); it += sizeof(this->rank); + std::copy_n(it, sizeof(this->hostHash), reinterpret_cast(&this->hostHash)); + it += sizeof(this->hostHash); std::copy_n(it, sizeof(this->transports), reinterpret_cast(&this->transports)); it += sizeof(this->transports); int8_t transportCount; @@ -136,15 +132,13 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) std::copy_n(it, sizeof(transportInfo.transport), reinterpret_cast(&transportInfo.transport)); it += sizeof(transportInfo.transport); if (transportInfo.transport == Transport::CudaIpc) { - cudaIpcMemHandle_t handle; - std::copy_n(it, sizeof(handle), reinterpret_cast(&handle)); - it += sizeof(handle); - transportInfo.cudaIpcHandle = handle; + std::copy_n(it, sizeof(transportInfo.cudaIpcBaseHandle), reinterpret_cast(&transportInfo.cudaIpcBaseHandle)); + it += sizeof(transportInfo.cudaIpcBaseHandle); + std::copy_n(it, sizeof(transportInfo.cudaIpcOffsetFromBase), reinterpret_cast(&transportInfo.cudaIpcOffsetFromBase)); + it += sizeof(transportInfo.cudaIpcOffsetFromBase); } else if (AllIBTransports.has(transportInfo.transport)) { - IbMrInfo info; - std::copy_n(it, sizeof(info), reinterpret_cast(&info)); - it += sizeof(info); - transportInfo.ibMrInfo = info; + std::copy_n(it, sizeof(transportInfo.ibMrInfo), reinterpret_cast(&transportInfo.ibMrInfo)); + it += sizeof(transportInfo.ibMrInfo); transportInfo.ibLocal = false; } else { throw std::runtime_error("Unknown transport"); @@ -155,7 +149,16 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) throw std::runtime_error("Deserialization failed"); } - dataInitialized = false; + if (transports.has(Transport::CudaIpc)) { + uint64_t localHostHash = getHostHash(); + if (localHostHash == this->hostHash) { + auto entry = getTransportInfo(Transport::CudaIpc); + void* base; + CUDATHROW(cudaIpcOpenMemHandle(&base, entry.cudaIpcBaseHandle, cudaIpcMemLazyEnablePeerAccess)); + data = static_cast(base) + entry.cudaIpcOffsetFromBase; + INFO(MSCCLPP_P2P, "Opened CUDA IPC handle at pointer %p", data); + } + } } } // namespace mscclpp diff --git a/src/utils.cc b/src/utils.cc index ebd31bfe..6954a64f 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -9,6 +9,7 @@ #include #include #include +#include // Get current Compute Capability // int mscclppCudaCompCap() { @@ -112,7 +113,7 @@ uint64_t getHash(const char* string, int n) * This string can be overridden by using the MSCCLPP_HOSTID env var. */ #define HOSTID_FILE "/proc/sys/kernel/random/boot_id" -uint64_t getHostHash(void) +uint64_t computeHostHash(void) { char hostHash[1024]; char* hostId; @@ -144,6 +145,12 @@ uint64_t getHostHash(void) return getHash(hostHash, strlen(hostHash)); } +uint64_t getHostHash(void) +{ + thread_local std::unique_ptr hostHash = std::make_unique(computeHostHash()); + return *hostHash; +} + /* Generate a hash of the unique identifying string for this process * that will be unique for both bare-metal and container instances * Equivalent of a hash of;