mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 17:26:04 +00:00
Lazy CUDA IPC handle opening
This commit is contained in:
@@ -25,6 +25,7 @@ struct TransportInfo
|
||||
struct RegisteredMemory::Impl
|
||||
{
|
||||
void* data;
|
||||
bool dataInitialized;
|
||||
size_t size;
|
||||
int rank;
|
||||
TransportFlags transports;
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
namespace mscclpp {
|
||||
|
||||
RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl)
|
||||
: data(data), size(size), rank(rank), transports(transports)
|
||||
: data(data), dataInitialized(true), size(size), rank(rank), transports(transports)
|
||||
{
|
||||
if (transports.has(Transport::CudaIpc)) {
|
||||
TransportInfo transportInfo;
|
||||
@@ -57,6 +57,18 @@ MSCCLPP_API_CPP RegisteredMemory::~RegisteredMemory() = default;
|
||||
|
||||
void* RegisteredMemory::data()
|
||||
{
|
||||
if (!pimpl->dataInitialized) {
|
||||
if (pimpl->transports.has(Transport::CudaIpc)) {
|
||||
auto entry = pimpl->getTransportInfo(Transport::CudaIpc);
|
||||
CUDATHROW(cudaIpcOpenMemHandle(&pimpl->data, entry.cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess));
|
||||
INFO(MSCCLPP_P2P, "Opened CUDA IPC handle for base point of %p", data);
|
||||
}
|
||||
else
|
||||
{
|
||||
pimpl->data = nullptr;
|
||||
}
|
||||
pimpl->dataInitialized = true;
|
||||
}
|
||||
return pimpl->data;
|
||||
}
|
||||
|
||||
@@ -141,11 +153,7 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization)
|
||||
throw std::runtime_error("Deserialization failed");
|
||||
}
|
||||
|
||||
if (transports.has(Transport::CudaIpc)) {
|
||||
auto entry = getTransportInfo(Transport::CudaIpc);
|
||||
CUDATHROW(cudaIpcOpenMemHandle(&data, entry.cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess));
|
||||
INFO(MSCCLPP_P2P, "Opened CUDA IPC handle for base point of %p", data);
|
||||
}
|
||||
dataInitialized = false;
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
Reference in New Issue
Block a user