Fix memory leak

This commit is contained in:
Binyang Li
2026-05-24 05:56:11 +00:00
parent 7308c321a0
commit 42ece408b9
5 changed files with 20 additions and 14 deletions

View File

@@ -58,8 +58,8 @@ class CustomizedComm:
"""Exposes all_reduce, all_gather, barrier with lazy per-size tuning."""
_TUNE_N_WARMUP = 5
_TUNE_N_GRAPH_LAUNCHES = 10
_TUNE_N_OPS_PER_GRAPH = 100
_TUNE_N_GRAPH_LAUNCHES = 5
_TUNE_N_OPS_PER_GRAPH = 20
_CANDIDATE_NBLOCKS = [4, 8, 16, 24, 32, 48, 56, 64, 128]
_CANDIDATE_NTHREADS = [512, 768, 1024]
_NBLOCKS_LIMIT = {

View File

@@ -9,6 +9,17 @@
namespace mscclpp {
namespace collective {
namespace {
// Per-context cache of input-side MemoryChannels keyed by input pointer.
// Lifetime is tied to AlgorithmCtx, so entries are released when the ctx is
// evicted from the framework's context cache (avoids unbounded growth across
// allreduce calls that pass different input buffers).
using InputChannelsCache =
std::unordered_map<const void*,
std::pair<std::vector<MemoryChannel>, std::shared_ptr<DeviceHandle<MemoryChannel>>>>;
constexpr const char* kInputChannelsExtraKey = "inputChannels";
} // namespace
template <ReduceOp OpType, typename T, typename AccumT = T>
__global__ void __launch_bounds__(512, 1)
allreduceFullmesh(T* buff, T* scratch, T* resultBuff, DeviceHandle<MemoryChannel>* memoryChannels,
@@ -195,17 +206,17 @@ CommResult AllreduceFullmesh::allreduceKernelFunc(
MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)output));
channelOutOffset = (char*)output - (char*)recvBasePtr;
}
std::shared_ptr<DeviceHandle<MemoryChannel>> inputChannelHandles;
if (this->memoryChannelsMap_.find(input) != this->memoryChannelsMap_.end()) {
inputChannelHandles = this->memoryChannelsMap_[input].second;
} else {
auto& inputChannelsCache = *static_cast<InputChannelsCache*>(ctx->extras.at(kInputChannelsExtraKey).get());
auto it = inputChannelsCache.find(input);
if (it == inputChannelsCache.end()) {
RegisteredMemory localMemory = comm_->registerMemory(const_cast<void*>(input), inputSize, Transport::CudaIpc);
std::vector<MemoryChannel> channels =
setupMemoryChannels(this->conns_, this->inputScratchSemaphores_, this->remoteScratchMemories_, localMemory,
nChannelsPerConnection_);
this->memoryChannelsMap_[input] = std::make_pair(channels, setupMemoryChannelDeviceHandles(channels));
auto handles = setupMemoryChannelDeviceHandles(channels);
it = inputChannelsCache.emplace(input, std::make_pair(std::move(channels), std::move(handles))).first;
}
inputChannelHandles = this->memoryChannelsMap_[input].second;
std::shared_ptr<DeviceHandle<MemoryChannel>> inputChannelHandles = it->second.second;
AllreduceFunc allreduce = dispatch<AllreduceAllconnectAdapter>(op, dtype, accumDtype);
if (!allreduce) {
@@ -267,6 +278,7 @@ std::shared_ptr<void> AllreduceFullmesh::initAllreduceContext(std::shared_ptr<Co
ctx->memoryChannels = setupMemoryChannels(this->conns_, ctx->memorySemaphores, ctx->registeredMemories, localMemory,
nChannelsPerConnection_);
ctx->memoryChannelDeviceHandles = setupMemoryChannelDeviceHandles(ctx->memoryChannels);
ctx->extras.insert({kInputChannelsExtraKey, std::make_shared<InputChannelsCache>()});
return ctx;
}

View File

@@ -211,8 +211,6 @@ std::shared_ptr<void> AllreduceRsAgZeroCopy::initAllreduceContext(std::shared_pt
// register input and output memories
RegisteredMemory inputMemory = comm->registerMemory((void*)input, size, Transport::CudaIpc);
RegisteredMemory outputMemory = comm->registerMemory(output, size, Transport::CudaIpc);
this->inputMemories_.push_back(inputMemory);
this->outputMemories_.push_back(outputMemory);
auto remoteInputMemories = setupRemoteMemories(comm, ctx->rank, inputMemory);
auto remoteOutputMemories = setupRemoteMemories(comm, ctx->rank, outputMemory);

View File

@@ -30,8 +30,6 @@ class AllreduceFullmesh : public mscclpp::AlgorithmBuilder {
std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> inputScratchSemaphores_;
std::vector<RegisteredMemory> remoteScratchMemories_;
RegisteredMemory localScratchMemory_;
std::unordered_map<const void*, std::pair<std::vector<MemoryChannel>, std::shared_ptr<DeviceHandle<MemoryChannel>>>>
memoryChannelsMap_;
bool symmetricMemory_ = false;
};
} // namespace collective

View File

@@ -27,8 +27,6 @@ class AllreduceRsAgZeroCopy : public mscclpp::AlgorithmBuilder {
int nChannelsPerConnection_;
std::vector<Connection> conns_;
std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> semaphores_;
std::vector<RegisteredMemory> inputMemories_;
std::vector<RegisteredMemory> outputMemories_;
std::vector<BaseMemoryChannel> baseChannels_;
std::shared_ptr<DeviceHandle<BaseMemoryChannel>> baseMemoryChannelHandles_;