Fix for multi-node test (#614)

Fix multi-node test

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Binyang Li
2025-08-14 20:44:43 -07:00
committed by GitHub
parent 671b688bb3
commit 03c0ff2a91
3 changed files with 36 additions and 26 deletions

View File

@@ -414,8 +414,8 @@ void BaseTestEngine::setupMeshConnections(std::vector<DeviceHandle<mscclpp::Port
mscclpp::RegisteredMemory& localRegMemory = (outputBuff) ? outputBufRegMem : inputBufRegMem;
// store memory to keep resource alive
inputMemory_ = inputBufRegMem;
outputMemory_ = outputBufRegMem;
inputMemories_.push_back(inputBufRegMem);
outputMemories_.push_back(outputBufRegMem);
setupMeshConnectionsInternal(connections, localRegMemory, remoteRegMemories);
if (setupChannel != nullptr) {
@@ -446,8 +446,8 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::MemoryChannel>& m
mscclpp::RegisteredMemory& localRegMemory =
(outputBuff && semantic == ChannelSemantic::PUT) ? outputBufRegMem : inputBufRegMem;
// store memory to keep resource alive
inputMemory_ = inputBufRegMem;
outputMemory_ = outputBufRegMem;
inputMemories_.push_back(inputBufRegMem);
outputMemories_.push_back(outputBufRegMem);
setupMeshConnectionsInternal(connections, localRegMemory, remoteRegMemories);
std::unordered_map<size_t, std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>>> memorySemaphores;
@@ -498,8 +498,8 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::MemoryChannel>& m
(getPacketBuff) ? getPacketBufRegMem : ((outputBuff) ? outputBufRegMem : inputBufRegMem);
// store memory to keep resource alive
scratchMemory_ = getPacketBufRegMem;
inputMemory_ = inputBufRegMem;
outputMemory_ = outputBufRegMem;
inputMemories_.push_back(inputBufRegMem);
outputMemories_.push_back(outputBufRegMem);
setupMeshConnectionsInternal(connections, localRegMemory, remoteRegMemories);

View File

@@ -132,8 +132,8 @@ class BaseTestEngine {
std::shared_ptr<mscclpp::Communicator> comm_;
std::shared_ptr<mscclpp::BaseProxyService> chanService_;
mscclpp::RegisteredMemory scratchMemory_;
mscclpp::RegisteredMemory inputMemory_;
mscclpp::RegisteredMemory outputMemory_;
std::vector<mscclpp::RegisteredMemory> inputMemories_;
std::vector<mscclpp::RegisteredMemory> outputMemories_;
cudaStream_t stream_;
int error_;
};