mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-21 21:39:21 +00:00
Fix correctness issue when mscclppDisableChannelCache set to true (#483)
If `mscclppDisableChannelCache` is set to true, we need to keep every channel's information alive so that the channel info on the GPU side is not released while it is still in use.
This commit is contained in:
@@ -92,6 +92,7 @@ struct ncclComm {
|
||||
std::unordered_map<channelKey, ChannelInfo> channelScratchInfos;
|
||||
std::shared_ptr<char> scratchBuff;
|
||||
std::vector<mscclpp::RegisteredMemory> remoteScratchRegMemories;
|
||||
std::vector<ChannelInfo> channelInfos;
|
||||
|
||||
uint32_t numScratchBuff;
|
||||
uint32_t buffFlag;
|
||||
@@ -274,10 +275,14 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
|
||||
setupMemoryChannels(comm, remoteMemories, const_cast<void*>((void*)recvBasePtr));
|
||||
ChannelInfo channelInfo{outChannels, setupMemoryChannelDeviceHandles(outChannels)};
|
||||
recvIt = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
|
||||
if (mscclppDisableChannelCache == true) {
|
||||
comm->channelInfos.push_back(channelInfo);
|
||||
}
|
||||
}
|
||||
|
||||
memoryChannels = sendIt->second.memoryChannelDeviceHandles.get();
|
||||
memoryOutChannels = recvIt->second.memoryChannelDeviceHandles.get();
|
||||
memoryOutChannels = mscclppDisableChannelCache == true ? comm->channelInfos.back().memoryChannelDeviceHandles.get()
|
||||
: recvIt->second.memoryChannelDeviceHandles.get();
|
||||
}
|
||||
|
||||
Op reduceOp = getReduceOp(op);
|
||||
@@ -356,8 +361,12 @@ static ncclResult_t ncclAllGatherFallback(const void* sendbuff, void* recvbuff,
|
||||
[](const mscclpp::MemoryChannel& memoryChannel) { return mscclpp::deviceHandle(memoryChannel); });
|
||||
ChannelInfo channelInfo{channels, setupMemoryChannelDeviceHandles(channels)};
|
||||
it = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
|
||||
if (mscclppDisableChannelCache == true) {
|
||||
comm->channelInfos.push_back(channelInfo);
|
||||
}
|
||||
}
|
||||
memoryChannels = it->second.memoryChannelDeviceHandles.get();
|
||||
memoryChannels = mscclppDisableChannelCache == true ? comm->channelInfos.back().memoryChannelDeviceHandles.get()
|
||||
: it->second.memoryChannelDeviceHandles.get();
|
||||
};
|
||||
|
||||
if (bytes <= 32 * (1 << 20)) {
|
||||
|
||||
Reference in New Issue
Block a user