diff --git a/test/mp_unit/communicator_tests.cu b/test/mp_unit/communicator_tests.cu index bef521af..212f31fe 100644 --- a/test/mp_unit/communicator_tests.cu +++ b/test/mp_unit/communicator_tests.cu @@ -44,20 +44,26 @@ void CommunicatorTestBase::setNumRanksToUse(int num) { numRanksToUse = num; } void CommunicatorTestBase::connectMesh(bool useIpc, bool useIb, bool useEthernet) { std::vector>> connectionFutures(numRanksToUse); + std::vector>> cpuConnectionFutures(numRanksToUse); for (int i = 0; i < numRanksToUse; i++) { if (i != gEnv->rank) { if ((rankToNode(i) == rankToNode(gEnv->rank)) && useIpc) { connectionFutures[i] = communicator->connect(mscclpp::Transport::CudaIpc, i); } else if (useIb) { connectionFutures[i] = communicator->connect(ibTransport, i); + cpuConnectionFutures[i] = communicator->connect({ibTransport, mscclpp::DeviceType::CPU}, i); } else if (useEthernet) { connectionFutures[i] = communicator->connect(mscclpp::Transport::Ethernet, i); + cpuConnectionFutures[i] = communicator->connect({mscclpp::Transport::Ethernet, mscclpp::DeviceType::CPU}, i); } } } for (int i = 0; i < numRanksToUse; i++) { if (i != gEnv->rank) { connections[i] = connectionFutures[i].get(); + if (cpuConnectionFutures[i].valid()) { + cpuConnections[i] = cpuConnectionFutures[i].get(); + } } } } @@ -241,7 +247,8 @@ TEST_F(CommunicatorTest, WriteWithHostSemaphores) { if (gEnv->rank >= numRanksToUse) return; std::unordered_map> semaphores; - for (auto entry : connections) { + // Iterate over cpuConnections, which is guaranteed by design to only contain CPU connections. + for (auto entry : cpuConnections) { auto& conn = entry.second; // Host2HostSemaphore cannot be used with CudaIpc transport if (conn->transport() == mscclpp::Transport::CudaIpc) continue; diff --git a/test/mp_unit/mp_unit_tests.hpp b/test/mp_unit/mp_unit_tests.hpp index 98b58cf7..654019fa 100644 --- a/test/mp_unit/mp_unit_tests.hpp +++ b/test/mp_unit/mp_unit_tests.hpp @@ -108,6 +108,7 @@ class CommunicatorTestBase : public MultiProcessTest { std::shared_ptr communicator; mscclpp::Transport ibTransport; std::unordered_map> connections; + std::unordered_map> cpuConnections; }; class CommunicatorTest : public CommunicatorTestBase {