diff --git a/examples/torch-integration/customized_comm_with_tuning.py b/examples/torch-integration/customized_comm_with_tuning.py index cb366167..cf475cdf 100644 --- a/examples/torch-integration/customized_comm_with_tuning.py +++ b/examples/torch-integration/customized_comm_with_tuning.py @@ -58,8 +58,8 @@ class CustomizedComm: """Exposes all_reduce, all_gather, barrier with lazy per-size tuning.""" _TUNE_N_WARMUP = 5 - _TUNE_N_GRAPH_LAUNCHES = 10 - _TUNE_N_OPS_PER_GRAPH = 100 + _TUNE_N_GRAPH_LAUNCHES = 5 + _TUNE_N_OPS_PER_GRAPH = 20 _CANDIDATE_NBLOCKS = [4, 8, 16, 24, 32, 48, 56, 64, 128] _CANDIDATE_NTHREADS = [512, 768, 1024] _NBLOCKS_LIMIT = { diff --git a/src/ext/collectives/allreduce/allreduce_fullmesh.cu b/src/ext/collectives/allreduce/allreduce_fullmesh.cu index f547ab4f..eb872624 100644 --- a/src/ext/collectives/allreduce/allreduce_fullmesh.cu +++ b/src/ext/collectives/allreduce/allreduce_fullmesh.cu @@ -9,6 +9,17 @@ namespace mscclpp { namespace collective { +namespace { +// Per-context cache of input-side MemoryChannels keyed by input pointer. +// Lifetime is tied to AlgorithmCtx, so entries are released when the ctx is +// evicted from the framework's context cache (avoids unbounded growth across +// allreduce calls that pass different input buffers). +using InputChannelsCache = + std::unordered_map, std::shared_ptr>>>; +constexpr const char* kInputChannelsExtraKey = "inputChannels"; +} // namespace + template __global__ void __launch_bounds__(512, 1) allreduceFullmesh(T* buff, T* scratch, T* resultBuff, DeviceHandle* memoryChannels, @@ -195,17 +206,17 @@ CommResult AllreduceFullmesh::allreduceKernelFunc( MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)output)); channelOutOffset = (char*)output - (char*)recvBasePtr; } - std::shared_ptr> inputChannelHandles; - if (this->memoryChannelsMap_.find(input) != this->memoryChannelsMap_.end()) { - inputChannelHandles = this->memoryChannelsMap_[input].second; - } else { + auto& inputChannelsCache = *static_cast(ctx->extras.at(kInputChannelsExtraKey).get()); + auto it = inputChannelsCache.find(input); + if (it == inputChannelsCache.end()) { RegisteredMemory localMemory = comm_->registerMemory(const_cast(input), inputSize, Transport::CudaIpc); std::vector channels = setupMemoryChannels(this->conns_, this->inputScratchSemaphores_, this->remoteScratchMemories_, localMemory, nChannelsPerConnection_); - this->memoryChannelsMap_[input] = std::make_pair(channels, setupMemoryChannelDeviceHandles(channels)); + auto handles = setupMemoryChannelDeviceHandles(channels); + it = inputChannelsCache.emplace(input, std::make_pair(std::move(channels), std::move(handles))).first; } - inputChannelHandles = this->memoryChannelsMap_[input].second; + std::shared_ptr> inputChannelHandles = it->second.second; AllreduceFunc allreduce = dispatch(op, dtype, accumDtype); if (!allreduce) { @@ -267,6 +278,7 @@ std::shared_ptr AllreduceFullmesh::initAllreduceContext(std::shared_ptrmemoryChannels = setupMemoryChannels(this->conns_, ctx->memorySemaphores, ctx->registeredMemories, localMemory, nChannelsPerConnection_); ctx->memoryChannelDeviceHandles = setupMemoryChannelDeviceHandles(ctx->memoryChannels); + ctx->extras.insert({kInputChannelsExtraKey, std::make_shared()}); return ctx; } diff --git a/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu b/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu index 877a722a..e7ed0cab 100644 --- a/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu +++ b/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu @@ -211,8 +211,6 @@ std::shared_ptr AllreduceRsAgZeroCopy::initAllreduceContext(std::shared_pt // register input and output memories RegisteredMemory inputMemory = comm->registerMemory((void*)input, size, Transport::CudaIpc); RegisteredMemory outputMemory = comm->registerMemory(output, size, Transport::CudaIpc); - this->inputMemories_.push_back(inputMemory); - this->outputMemories_.push_back(outputMemory); auto remoteInputMemories = setupRemoteMemories(comm, ctx->rank, inputMemory); auto remoteOutputMemories = setupRemoteMemories(comm, ctx->rank, outputMemory); diff --git a/src/ext/collectives/include/allreduce/allreduce_fullmesh.hpp b/src/ext/collectives/include/allreduce/allreduce_fullmesh.hpp index a54352b3..e0c63a3d 100644 --- a/src/ext/collectives/include/allreduce/allreduce_fullmesh.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_fullmesh.hpp @@ -30,8 +30,6 @@ class AllreduceFullmesh : public mscclpp::AlgorithmBuilder { std::vector> inputScratchSemaphores_; std::vector remoteScratchMemories_; RegisteredMemory localScratchMemory_; - std::unordered_map, std::shared_ptr>>> - memoryChannelsMap_; bool symmetricMemory_ = false; }; } // namespace collective diff --git a/src/ext/collectives/include/allreduce/allreduce_rsag_zero_copy.hpp b/src/ext/collectives/include/allreduce/allreduce_rsag_zero_copy.hpp index 05bf2ef3..528d9708 100644 --- a/src/ext/collectives/include/allreduce/allreduce_rsag_zero_copy.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_rsag_zero_copy.hpp @@ -27,8 +27,6 @@ class AllreduceRsAgZeroCopy : public mscclpp::AlgorithmBuilder { int nChannelsPerConnection_; std::vector conns_; std::vector> semaphores_; - std::vector inputMemories_; - std::vector outputMemories_; std::vector baseChannels_; std::shared_ptr> baseMemoryChannelHandles_;