diff --git a/src/registered_memory.cc b/src/registered_memory.cc index 9f049cf8..85702ea1 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -201,6 +201,7 @@ RegisteredMemory::Impl::Impl(const std::vector::const_iterator& begin, } // Next decide how to set this->data + this->data = nullptr; if (getHostHash() == this->hostHash && getPidHash() == this->pidHash) { // The memory is local to the process, so originalDataPtr is valid as is this->data = this->originalDataPtr; @@ -211,22 +212,32 @@ RegisteredMemory::Impl::Impl(const std::vector::const_iterator& begin, if (this->isCuMemMapAlloc) { #if (CUDA_NVLS_API_AVAILABLE) CUmemGenericAllocationHandle handle; - if (getNvlsMemHandleType() == CU_MEM_HANDLE_TYPE_FABRIC) { - MSCCLPP_CUTHROW(cuMemImportFromShareableHandle(&handle, entry.shareableHandle, getNvlsMemHandleType())); + if (getHostHash() != this->hostHash) { + // TODO: only open handle if in same MNNVL domain + CUresult err = cuMemImportFromShareableHandle(&handle, entry.shareableHandle, getNvlsMemHandleType()); + if (err != CUDA_SUCCESS) { + INFO(MSCCLPP_P2P, "Failed to import shareable handle from host: 0x%lx, may not be in the same MNNVL domain", + hostHash); + return; + } } else { - int rootPidFd = syscall(SYS_pidfd_open, entry.rootPid, 0); - if (rootPidFd < 0) { - throw SysError("pidfd_open() failed", errno); + if (getNvlsMemHandleType() == CU_MEM_HANDLE_TYPE_FABRIC) { + MSCCLPP_CUTHROW(cuMemImportFromShareableHandle(&handle, entry.shareableHandle, getNvlsMemHandleType())); + } else { + int rootPidFd = syscall(SYS_pidfd_open, entry.rootPid, 0); + if (rootPidFd < 0) { + throw SysError("pidfd_open() failed", errno); + } + int fd = syscall(SYS_pidfd_getfd, rootPidFd, entry.fileDesc, 0); + if (fd < 0) { + throw SysError("pidfd_getfd() failed", errno); + } + INFO(MSCCLPP_P2P, "Get file descriptor %d from pidfd %d on peer 0x%lx", fd, rootPidFd, hostHash); + MSCCLPP_CUTHROW(cuMemImportFromShareableHandle(&handle, reinterpret_cast(fd), + CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); + close(rootPidFd); + close(fd); } - int fd = syscall(SYS_pidfd_getfd, rootPidFd, entry.fileDesc, 0); - if (fd < 0) { - throw SysError("pidfd_getfd() failed", errno); - } - INFO(MSCCLPP_P2P, "Get file descriptor %d from pidfd %d on peer 0x%lx", fd, rootPidFd, hostHash); - MSCCLPP_CUTHROW(cuMemImportFromShareableHandle(&handle, reinterpret_cast(fd), - CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); - close(rootPidFd); - close(fd); } size_t minGran = detail::getMulticastGranularity(this->baseDataSize, CU_MULTICAST_GRANULARITY_MINIMUM); size_t recommendedGran = @@ -240,14 +251,13 @@ RegisteredMemory::Impl::Impl(const std::vector::const_iterator& begin, throw Error("CUDA does not support NVLS. Please ensure your CUDA version supports NVLS to use this feature.", ErrorCode::InvalidUsage); #endif - } else { + } else if (getHostHash() == this->hostHash) { MSCCLPP_CUDATHROW(cudaIpcOpenMemHandle(&base, entry.cudaIpcBaseHandle, cudaIpcMemLazyEnablePeerAccess)); this->data = static_cast(base) + entry.cudaIpcOffsetFromBase; } + } + if (this->data != nullptr) { INFO(MSCCLPP_P2P, "Opened CUDA IPC handle at pointer %p", this->data); - } else { - // No valid data pointer can be set - this->data = nullptr; } } diff --git a/test/mscclpp-test/common.cc b/test/mscclpp-test/common.cc index a1d76539..c3b232f6 100644 --- a/test/mscclpp-test/common.cc +++ b/test/mscclpp-test/common.cc @@ -414,8 +414,8 @@ void BaseTestEngine::setupMeshConnections(std::vector& m mscclpp::RegisteredMemory& localRegMemory = (outputBuff && semantic == ChannelSemantic::PUT) ? outputBufRegMem : inputBufRegMem; // store memory to keep resource alive - inputMemory_ = inputBufRegMem; - outputMemory_ = outputBufRegMem; + inputMemories_.push_back(inputBufRegMem); + outputMemories_.push_back(outputBufRegMem); setupMeshConnectionsInternal(connections, localRegMemory, remoteRegMemories); std::unordered_map>> memorySemaphores; @@ -498,8 +498,8 @@ void BaseTestEngine::setupMeshConnections(std::vector& m (getPacketBuff) ? getPacketBufRegMem : ((outputBuff) ? outputBufRegMem : inputBufRegMem); // store memory to keep resource alive scratchMemory_ = getPacketBufRegMem; - inputMemory_ = inputBufRegMem; - outputMemory_ = outputBufRegMem; + inputMemories_.push_back(inputBufRegMem); + outputMemories_.push_back(outputBufRegMem); setupMeshConnectionsInternal(connections, localRegMemory, remoteRegMemories); diff --git a/test/mscclpp-test/common.hpp b/test/mscclpp-test/common.hpp index fcd2e254..c2725450 100644 --- a/test/mscclpp-test/common.hpp +++ b/test/mscclpp-test/common.hpp @@ -132,8 +132,8 @@ class BaseTestEngine { std::shared_ptr comm_; std::shared_ptr chanService_; mscclpp::RegisteredMemory scratchMemory_; - mscclpp::RegisteredMemory inputMemory_; - mscclpp::RegisteredMemory outputMemory_; + std::vector inputMemories_; + std::vector outputMemories_; cudaStream_t stream_; int error_; };