diff --git a/src/registered_memory.cc b/src/registered_memory.cc index 1f683303..6ccb3d07 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -69,17 +69,20 @@ std::shared_ptr getPeerMemoryHandle(cudaIpcMemHandle_t ipcHandle) { }; #if defined(__HIP_PLATFORM_AMD__) static std::unordered_map> peerMemoryHandleMap; - std::mutex mutex; + static std::mutex mutex; std::lock_guard lock(mutex); auto it = peerMemoryHandleMap.find(ipcHandle); if (it != peerMemoryHandleMap.end()) { if (auto ptr = it->second.lock()) { return ptr; } - throw mscclpp::Error("Failed to get peer memory handle, may already be closed", mscclpp::ErrorCode::InvalidUsage); } MSCCLPP_CUDATHROW(cudaIpcOpenMemHandle(&addr, ipcHandle, cudaIpcMemLazyEnablePeerAccess)); - std::shared_ptr ptr = std::shared_ptr(addr, deleter); + std::shared_ptr ptr = std::shared_ptr(addr, [ipcHandle, deleter](void* p) { + deleter(p); + std::lock_guard lock(mutex); + peerMemoryHandleMap.erase(ipcHandle); + }); peerMemoryHandleMap[ipcHandle] = ptr; return ptr; #else @@ -294,6 +297,9 @@ RegisteredMemory::Impl::Impl(const std::vector::const_iterator& begin, #endif // !(CUDA_NVLS_API_AVAILABLE) } else if (getHostHash() == this->hostHash) { this->peerHandle = getPeerMemoryHandle(entry.cudaIpcBaseHandle); + if (!this->peerHandle) { + throw Error("Failed to open CUDA IPC handle, may already be closed", ErrorCode::InvalidUsage); + } this->data = static_cast(this->peerHandle.get()) + entry.cudaIpcOffsetFromBase; } }