Fix correctness issue when mscclppDisableChannelCache is set to true (#483)

If `mscclppDisableChannelCache` is set to true, we need to keep every
channel's information alive so that the channel info on the GPU side is not released.
This commit is contained in:
Binyang Li
2025-03-19 14:55:37 -07:00
committed by GitHub
parent b6a179faff
commit 89f7573adf

View File

@@ -92,6 +92,7 @@ struct ncclComm {
std::unordered_map<channelKey, ChannelInfo> channelScratchInfos;
std::shared_ptr<char> scratchBuff;
std::vector<mscclpp::RegisteredMemory> remoteScratchRegMemories;
std::vector<ChannelInfo> channelInfos;
uint32_t numScratchBuff;
uint32_t buffFlag;
@@ -274,10 +275,14 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
setupMemoryChannels(comm, remoteMemories, const_cast<void*>((void*)recvBasePtr));
ChannelInfo channelInfo{outChannels, setupMemoryChannelDeviceHandles(outChannels)};
recvIt = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
if (mscclppDisableChannelCache == true) {
comm->channelInfos.push_back(channelInfo);
}
}
memoryChannels = sendIt->second.memoryChannelDeviceHandles.get();
memoryOutChannels = recvIt->second.memoryChannelDeviceHandles.get();
memoryOutChannels = mscclppDisableChannelCache == true ? comm->channelInfos.back().memoryChannelDeviceHandles.get()
: recvIt->second.memoryChannelDeviceHandles.get();
}
Op reduceOp = getReduceOp(op);
@@ -356,8 +361,12 @@ static ncclResult_t ncclAllGatherFallback(const void* sendbuff, void* recvbuff,
[](const mscclpp::MemoryChannel& memoryChannel) { return mscclpp::deviceHandle(memoryChannel); });
ChannelInfo channelInfo{channels, setupMemoryChannelDeviceHandles(channels)};
it = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
if (mscclppDisableChannelCache == true) {
comm->channelInfos.push_back(channelInfo);
}
}
memoryChannels = it->second.memoryChannelDeviceHandles.get();
memoryChannels = mscclppDisableChannelCache == true ? comm->channelInfos.back().memoryChannelDeviceHandles.get()
: it->second.memoryChannelDeviceHandles.get();
};
if (bytes <= 32 * (1 << 20)) {