From 5b7e76cae41f6d3eeb58a5eed4bbd80120efa4b6 Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Mon, 1 May 2023 22:25:14 +0000 Subject: [PATCH] all tests are passing with memory registeration --- src/connection.cc | 19 +++++++++++++++---- src/registered_memory.cc | 1 + 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/connection.cc b/src/connection.cc index 5289ab59..2cfa7205 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -1,3 +1,4 @@ +#include #include "connection.hpp" #include "checks.hpp" #include "infiniband/verbs.h" @@ -142,15 +143,25 @@ void IBConnection::flush() void IBConnection::startSetup(std::shared_ptr bootstrap) { - bootstrap->send(&qp->getInfo(), sizeof(qp->getInfo()), remoteRank(), tag()); - bootstrap->send(&transport_, sizeof(transport_), remoteRank(), tag()); + std::vector ibQpTransport; + std::copy_n(reinterpret_cast(&qp->getInfo()), sizeof(qp->getInfo()), std::back_inserter(ibQpTransport)); + std::copy_n(reinterpret_cast(&transport_), sizeof(transport_), std::back_inserter(ibQpTransport)); + + bootstrap->send(ibQpTransport.data(), ibQpTransport.size(), remoteRank(), tag()); } void IBConnection::endSetup(std::shared_ptr bootstrap) { + std::vector ibQpTransport(sizeof(IbQpInfo) + sizeof(Transport)); + bootstrap->recv(ibQpTransport.data(), ibQpTransport.size(), remoteRank(), tag()); + IbQpInfo qpInfo; - bootstrap->recv(&qpInfo, sizeof(qpInfo), remoteRank(), tag()); - bootstrap->recv(&remoteTransport_, sizeof(remoteTransport_), remoteRank(), tag()); + auto it = ibQpTransport.begin(); + std::copy_n(it, sizeof(qpInfo), reinterpret_cast(&qpInfo)); + it += sizeof(qpInfo); + std::copy_n(it, sizeof(remoteTransport_), reinterpret_cast(&remoteTransport_)); + it += sizeof(qpInfo); + qp->rtr(qpInfo); qp->rts(); } diff --git a/src/registered_memory.cc b/src/registered_memory.cc index 1215c0e2..abf17a8b 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -30,6 +30,7 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t transportInfo.ibLocal = true; transportInfo.ibMrInfo = mr->getInfo(); this->transportInfos.push_back(transportInfo); + INFO(MSCCLPP_NET, "IB mr for address %p with size %ld is registered", data, size); }; if (transports.has(Transport::IB0)) addIb(Transport::IB0);