diff --git a/src/registered_memory.cc b/src/registered_memory.cc index 42a03a8e..516a4c64 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -13,8 +13,11 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t TransportInfo transportInfo; transportInfo.transport = Transport::CudaIpc; cudaIpcMemHandle_t handle; - // TODO: translate data to a base pointer - CUDATHROW(cudaIpcGetMemHandle(&handle, data)); + + void* baseDataPtr; + size_t baseDataSize; // dummy + CUTHROW(cuMemGetAddressRange((CUdeviceptr*)&baseDataPtr, &baseDataSize, (CUdeviceptr)data)); + CUDATHROW(cudaIpcGetMemHandle(&handle, baseDataPtr)); transportInfo.cudaIpcHandle = handle; this->transportInfos.push_back(transportInfo); } @@ -72,7 +75,7 @@ TransportFlags RegisteredMemory::transports() return pimpl->transports; } -std::vector RegisteredMemory::serialize() +MSCCLPP_API_CPP std::vector RegisteredMemory::serialize() { std::vector result; std::copy_n(reinterpret_cast(&pimpl->size), sizeof(pimpl->size), std::back_inserter(result)); @@ -97,7 +100,7 @@ std::vector RegisteredMemory::serialize() return result; } -RegisteredMemory RegisteredMemory::deserialize(const std::vector& data) +MSCCLPP_API_CPP RegisteredMemory RegisteredMemory::deserialize(const std::vector& data) { return RegisteredMemory(std::make_shared(data)); } @@ -140,10 +143,7 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) if (transports.has(Transport::CudaIpc)) { auto entry = getTransportInfo(Transport::CudaIpc); - void* baseDataPtr; - size_t baseDataSize; // dummy - CUTHROW(cuMemGetAddressRange((CUdeviceptr*)&baseDataPtr, &baseDataSize, (CUdeviceptr)data)); - CUDATHROW(cudaIpcOpenMemHandle(&baseDataPtr, entry.cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess)); + CUDATHROW(cudaIpcOpenMemHandle(&data, entry.cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess)); INFO(MSCCLPP_P2P, "Opened CUDA IPC handle for base point of %p", data); } } diff --git a/tests/communicator_test_cpp.cc b/tests/communicator_test_cpp.cc index a05c8981..7fccf57b 100644 --- a/tests/communicator_test_cpp.cc +++ b/tests/communicator_test_cpp.cc @@ -55,6 +55,25 @@ void test_communicator(int rank, int worldSize, int nranksPerNode) CUDATHROW(cudaMalloc(&devicePtr, size)); auto registeredMemory = communicator->registerMemory(devicePtr, size, mscclpp::Transport::CudaIpc | myIbDevice); + for (int i = 0; i < worldSize; i++) { + if (i != rank){ + auto serialized = registeredMemory.serialize(); + int serializedSize = serialized.size(); + bootstrap->send(&serializedSize, sizeof(int), i, 0); + bootstrap->send(serialized.data(), serializedSize, i, 1); + } + } + for (int i = 0; i < worldSize; i++) { + if (i != rank){ + int deserializedSize; + bootstrap->recv(&deserializedSize, sizeof(int), i, 0); + std::vector deserialized(deserializedSize); + bootstrap->recv(deserialized.data(), deserializedSize, i, 1); + // auto deserializedRegisteredMemory = mscclpp::RegisteredMemory::deserialize(deserialized); + } + } + + if (bootstrap->getRank() == 0) std::cout << "Memory registeration passed" << std::endl;