Lazy CUDA IPC handle opening

This commit is contained in:
Olli Saarikivi
2023-04-28 00:30:07 +00:00
parent 962e63b11a
commit fa0fcb470e
2 changed files with 15 additions and 6 deletions

View File

@@ -25,6 +25,7 @@ struct TransportInfo
struct RegisteredMemory::Impl
{
void* data;
bool dataInitialized;
size_t size;
int rank;
TransportFlags transports;

View File

@@ -7,7 +7,7 @@
namespace mscclpp {
RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl)
: data(data), size(size), rank(rank), transports(transports)
: data(data), dataInitialized(true), size(size), rank(rank), transports(transports)
{
if (transports.has(Transport::CudaIpc)) {
TransportInfo transportInfo;
@@ -57,6 +57,18 @@ MSCCLPP_API_CPP RegisteredMemory::~RegisteredMemory() = default;
// Returns the base pointer of this registered memory.
//
// If the pointer has not been resolved yet (pimpl->dataInitialized is false),
// it is resolved on first use and cached in pimpl->data:
//   - when the CudaIpc transport is present, the stored CUDA IPC handle is
//     opened with cudaIpcOpenMemHandle and the mapped pointer is stored;
//   - otherwise the pointer is set to nullptr.
// Later calls return the cached pointer without reopening the handle.
//
// NOTE(review): the check-then-set on dataInitialized is not synchronized here;
// confirm callers do not invoke data() concurrently on the same object.
void* RegisteredMemory::data()
{
  if (!pimpl->dataInitialized) {
    if (pimpl->transports.has(Transport::CudaIpc)) {
      auto entry = pimpl->getTransportInfo(Transport::CudaIpc);
      // The flag below is only set after this call; presumably CUDATHROW throws on
      // failure, in which case the next data() call retries — confirm.
      CUDATHROW(cudaIpcOpenMemHandle(&pimpl->data, entry.cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess));
      // Log the mapped pointer explicitly: inside this member function a bare
      // `data` names RegisteredMemory::data itself, not the pointer field.
      INFO(MSCCLPP_P2P, "Opened CUDA IPC handle for base point of %p", pimpl->data);
    } else {
      pimpl->data = nullptr;
    }
    pimpl->dataInitialized = true;
  }
  return pimpl->data;
}
@@ -141,11 +153,7 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization)
throw std::runtime_error("Deserialization failed");
}
if (transports.has(Transport::CudaIpc)) {
auto entry = getTransportInfo(Transport::CudaIpc);
CUDATHROW(cudaIpcOpenMemHandle(&data, entry.cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess));
INFO(MSCCLPP_P2P, "Opened CUDA IPC handle for base point of %p", data);
}
dataInitialized = false;
}
} // namespace mscclpp