Use IB transport flags only when an IB device exists (#355)

This commit is contained in:
Changho Hwang
2024-09-19 00:13:11 -07:00
committed by GitHub
parent 5c4e105814
commit 74130c7c5e

View File

@@ -15,6 +15,7 @@
#include <fstream>
#include <iomanip>
#include <iostream>
#include <mscclpp/core.hpp>
#include <mscclpp/utils.hpp>
#include <nlohmann/json.hpp>
#include <sstream>
@@ -399,7 +400,8 @@ void BaseTestEngine::setupMeshConnectionsInternal(
void BaseTestEngine::setupMeshConnections(std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>>& proxyChannels,
void* inputBuff, size_t inputBuffBytes, void* outputBuff,
size_t outputBuffBytes, SetupChannelFunc setupChannel) {
const mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc | IBs[args_.gpuNum];
mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc;
if (mscclpp::getIBDeviceCount() > 0) allTransports |= IBs[args_.gpuNum];
mscclpp::RegisteredMemory inputBufRegMem = comm_->registerMemory(inputBuff, inputBuffBytes, allTransports);
mscclpp::RegisteredMemory outputBufRegMem;
if (outputBuff) {
@@ -429,7 +431,8 @@ void BaseTestEngine::setupMeshConnections(std::vector<DeviceHandle<mscclpp::Simp
void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::SmChannel>& smChannels, void* inputBuff,
size_t inputBuffBytes, void* outputBuff, size_t outputBuffBytes,
ChannelSemantic semantic, size_t nChannelPerConnection) {
const mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc | IBs[args_.gpuNum];
mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc;
if (mscclpp::getIBDeviceCount() > 0) allTransports |= IBs[args_.gpuNum];
mscclpp::RegisteredMemory inputBufRegMem = comm_->registerMemory(inputBuff, inputBuffBytes, allTransports);
mscclpp::RegisteredMemory getPacketBufRegMem;
mscclpp::RegisteredMemory outputBufRegMem;
@@ -469,7 +472,8 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::SmChannel>& smCha
void* inputBuff, size_t inputBuffBytes, void* putPacketBuff,
size_t putPacketBuffBytes, void* getPacketBuff, size_t getPacketBuffBytes,
void* outputBuff, size_t outputBuffBytes) {
const mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc | IBs[args_.gpuNum];
mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc;
if (mscclpp::getIBDeviceCount() > 0) allTransports |= IBs[args_.gpuNum];
mscclpp::RegisteredMemory inputBufRegMem = comm_->registerMemory(inputBuff, inputBuffBytes, allTransports);
mscclpp::RegisteredMemory putPacketBufRegMem;
mscclpp::RegisteredMemory getPacketBufRegMem;