diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu index 4f35fdab..1bac2155 100644 --- a/apps/nccl/src/nccl.cu +++ b/apps/nccl/src/nccl.cu @@ -92,6 +92,7 @@ struct ncclComm { std::unordered_map channelScratchInfos; std::shared_ptr scratchBuff; std::vector remoteScratchRegMemories; + std::vector channelInfos; uint32_t numScratchBuff; uint32_t buffFlag; @@ -274,10 +275,14 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff, setupMemoryChannels(comm, remoteMemories, const_cast((void*)recvBasePtr)); ChannelInfo channelInfo{outChannels, setupMemoryChannelDeviceHandles(outChannels)}; recvIt = comm->channelOutInfos.emplace(recvKey, channelInfo).first; + if (mscclppDisableChannelCache == true) { + comm->channelInfos.push_back(channelInfo); + } } memoryChannels = sendIt->second.memoryChannelDeviceHandles.get(); - memoryOutChannels = recvIt->second.memoryChannelDeviceHandles.get(); + memoryOutChannels = mscclppDisableChannelCache == true ? comm->channelInfos.back().memoryChannelDeviceHandles.get() + : recvIt->second.memoryChannelDeviceHandles.get(); } Op reduceOp = getReduceOp(op); @@ -356,8 +361,12 @@ static ncclResult_t ncclAllGatherFallback(const void* sendbuff, void* recvbuff, [](const mscclpp::MemoryChannel& memoryChannel) { return mscclpp::deviceHandle(memoryChannel); }); ChannelInfo channelInfo{channels, setupMemoryChannelDeviceHandles(channels)}; it = comm->channelOutInfos.emplace(recvKey, channelInfo).first; + if (mscclppDisableChannelCache == true) { + comm->channelInfos.push_back(channelInfo); + } } - memoryChannels = it->second.memoryChannelDeviceHandles.get(); + memoryChannels = mscclppDisableChannelCache == true ? comm->channelInfos.back().memoryChannelDeviceHandles.get() + : it->second.memoryChannelDeviceHandles.get(); }; if (bytes <= 32 * (1 << 20)) {