mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-24 23:06:17 +00:00
Fix memory leak
This commit is contained in:
@@ -58,8 +58,8 @@ class CustomizedComm:
|
||||
"""Exposes all_reduce, all_gather, barrier with lazy per-size tuning."""
|
||||
|
||||
_TUNE_N_WARMUP = 5
|
||||
_TUNE_N_GRAPH_LAUNCHES = 10
|
||||
_TUNE_N_OPS_PER_GRAPH = 100
|
||||
_TUNE_N_GRAPH_LAUNCHES = 5
|
||||
_TUNE_N_OPS_PER_GRAPH = 20
|
||||
_CANDIDATE_NBLOCKS = [4, 8, 16, 24, 32, 48, 56, 64, 128]
|
||||
_CANDIDATE_NTHREADS = [512, 768, 1024]
|
||||
_NBLOCKS_LIMIT = {
|
||||
|
||||
@@ -9,6 +9,17 @@
|
||||
namespace mscclpp {
|
||||
namespace collective {
|
||||
|
||||
namespace {
|
||||
// Per-context cache of input-side MemoryChannels keyed by input pointer. Each entry pairs the channels with their device handles.
|
||||
// Lifetime is tied to AlgorithmCtx, so entries are released when the ctx is
|
||||
// evicted from the framework's context cache (avoids unbounded growth across
|
||||
// allreduce calls that pass different input buffers).
|
||||
using InputChannelsCache =
|
||||
std::unordered_map<const void*,
|
||||
std::pair<std::vector<MemoryChannel>, std::shared_ptr<DeviceHandle<MemoryChannel>>>>;
|
||||
constexpr const char* kInputChannelsExtraKey = "inputChannels";  // key in ctx->extras under which the InputChannelsCache is stored
|
||||
} // namespace
|
||||
|
||||
template <ReduceOp OpType, typename T, typename AccumT = T>
|
||||
__global__ void __launch_bounds__(512, 1)
|
||||
allreduceFullmesh(T* buff, T* scratch, T* resultBuff, DeviceHandle<MemoryChannel>* memoryChannels,
|
||||
@@ -195,17 +206,17 @@ CommResult AllreduceFullmesh::allreduceKernelFunc(
|
||||
MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)output));
|
||||
channelOutOffset = (char*)output - (char*)recvBasePtr;
|
||||
}
|
||||
std::shared_ptr<DeviceHandle<MemoryChannel>> inputChannelHandles;
|
||||
if (this->memoryChannelsMap_.find(input) != this->memoryChannelsMap_.end()) {
|
||||
inputChannelHandles = this->memoryChannelsMap_[input].second;
|
||||
} else {
|
||||
auto& inputChannelsCache = *static_cast<InputChannelsCache*>(ctx->extras.at(kInputChannelsExtraKey).get());
|
||||
auto it = inputChannelsCache.find(input);
|
||||
if (it == inputChannelsCache.end()) {
|
||||
RegisteredMemory localMemory = comm_->registerMemory(const_cast<void*>(input), inputSize, Transport::CudaIpc);
|
||||
std::vector<MemoryChannel> channels =
|
||||
setupMemoryChannels(this->conns_, this->inputScratchSemaphores_, this->remoteScratchMemories_, localMemory,
|
||||
nChannelsPerConnection_);
|
||||
this->memoryChannelsMap_[input] = std::make_pair(channels, setupMemoryChannelDeviceHandles(channels));
|
||||
auto handles = setupMemoryChannelDeviceHandles(channels);
|
||||
it = inputChannelsCache.emplace(input, std::make_pair(std::move(channels), std::move(handles))).first;
|
||||
}
|
||||
inputChannelHandles = this->memoryChannelsMap_[input].second;
|
||||
std::shared_ptr<DeviceHandle<MemoryChannel>> inputChannelHandles = it->second.second;
|
||||
|
||||
AllreduceFunc allreduce = dispatch<AllreduceAllconnectAdapter>(op, dtype, accumDtype);
|
||||
if (!allreduce) {
|
||||
@@ -267,6 +278,7 @@ std::shared_ptr<void> AllreduceFullmesh::initAllreduceContext(std::shared_ptr<Co
|
||||
ctx->memoryChannels = setupMemoryChannels(this->conns_, ctx->memorySemaphores, ctx->registeredMemories, localMemory,
|
||||
nChannelsPerConnection_);
|
||||
ctx->memoryChannelDeviceHandles = setupMemoryChannelDeviceHandles(ctx->memoryChannels);
|
||||
ctx->extras.insert({kInputChannelsExtraKey, std::make_shared<InputChannelsCache>()});
|
||||
return ctx;
|
||||
}
|
||||
|
||||
|
||||
@@ -211,8 +211,6 @@ std::shared_ptr<void> AllreduceRsAgZeroCopy::initAllreduceContext(std::shared_pt
|
||||
// register input and output memories
|
||||
RegisteredMemory inputMemory = comm->registerMemory((void*)input, size, Transport::CudaIpc);
|
||||
RegisteredMemory outputMemory = comm->registerMemory(output, size, Transport::CudaIpc);
|
||||
this->inputMemories_.push_back(inputMemory);
|
||||
this->outputMemories_.push_back(outputMemory);
|
||||
|
||||
auto remoteInputMemories = setupRemoteMemories(comm, ctx->rank, inputMemory);
|
||||
auto remoteOutputMemories = setupRemoteMemories(comm, ctx->rank, outputMemory);
|
||||
|
||||
@@ -30,8 +30,6 @@ class AllreduceFullmesh : public mscclpp::AlgorithmBuilder {
|
||||
std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> inputScratchSemaphores_;
|
||||
std::vector<RegisteredMemory> remoteScratchMemories_;
|
||||
RegisteredMemory localScratchMemory_;
|
||||
std::unordered_map<const void*, std::pair<std::vector<MemoryChannel>, std::shared_ptr<DeviceHandle<MemoryChannel>>>>
|
||||
memoryChannelsMap_;
|
||||
bool symmetricMemory_ = false;
|
||||
};
|
||||
} // namespace collective
|
||||
|
||||
@@ -27,8 +27,6 @@ class AllreduceRsAgZeroCopy : public mscclpp::AlgorithmBuilder {
|
||||
int nChannelsPerConnection_;
|
||||
std::vector<Connection> conns_;
|
||||
std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> semaphores_;
|
||||
std::vector<RegisteredMemory> inputMemories_;
|
||||
std::vector<RegisteredMemory> outputMemories_;
|
||||
|
||||
std::vector<BaseMemoryChannel> baseChannels_;
|
||||
std::shared_ptr<DeviceHandle<BaseMemoryChannel>> baseMemoryChannelHandles_;
|
||||
|
||||
Reference in New Issue
Block a user