diff --git a/src/include/registered_memory.hpp b/src/include/registered_memory.hpp
index 1c37ff04..88c1005d 100644
--- a/src/include/registered_memory.hpp
+++ b/src/include/registered_memory.hpp
@@ -25,6 +25,7 @@ struct TransportInfo
 struct RegisteredMemory::Impl
 {
   void* data;
+  bool dataInitialized;
   size_t size;
   int rank;
   TransportFlags transports;
diff --git a/src/registered_memory.cc b/src/registered_memory.cc
index 516a4c64..470e7c10 100644
--- a/src/registered_memory.cc
+++ b/src/registered_memory.cc
@@ -7,7 +7,7 @@ namespace mscclpp {
 
 RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports,
                              Communicator::Impl& commImpl)
-  : data(data), size(size), rank(rank), transports(transports)
+  : data(data), dataInitialized(true), size(size), rank(rank), transports(transports)
 {
   if (transports.has(Transport::CudaIpc)) {
     TransportInfo transportInfo;
@@ -57,6 +57,20 @@ MSCCLPP_API_CPP RegisteredMemory::~RegisteredMemory() = default;
 
 void* RegisteredMemory::data()
 {
+  // Deserialized instances defer opening the CUDA IPC handle until the base pointer is first requested.
+  if (!pimpl->dataInitialized) {
+    if (pimpl->transports.has(Transport::CudaIpc)) {
+      auto entry = pimpl->getTransportInfo(Transport::CudaIpc);
+      CUDATHROW(cudaIpcOpenMemHandle(&pimpl->data, entry.cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess));
+      // Log the mapped pointer; a bare `data` here would name this member function, not the pointer.
+      INFO(MSCCLPP_P2P, "Opened CUDA IPC handle for base point of %p", pimpl->data);
+    }
+    else
+    {
+      pimpl->data = nullptr;
+    }
+    pimpl->dataInitialized = true;
+  }
   return pimpl->data;
 }
 
@@ -141,11 +155,7 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization)
     throw std::runtime_error("Deserialization failed");
   }
 
-  if (transports.has(Transport::CudaIpc)) {
-    auto entry = getTransportInfo(Transport::CudaIpc);
-    CUDATHROW(cudaIpcOpenMemHandle(&data, entry.cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess));
-    INFO(MSCCLPP_P2P, "Opened CUDA IPC handle for base point of %p", data);
-  }
+  dataInitialized = false;
 }
 
 } // namespace mscclpp