mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 09:17:06 +00:00
address flagBuffer ownership issue (#749)
This pull request updates the handling of the default flag buffer in the C++ and Python bindings to ensure proper memory management when interfacing with Python. It makes sure the buffer is not deallocated when ownership is transferred from C++ to Python.
This commit is contained in:
@@ -199,18 +199,23 @@ std::shared_ptr<Algorithm> DslAlgorithm::build() { return shared_from_this(); }
|
||||
// TODO: implement this
|
||||
void DslAlgorithm::reset() {}
|
||||
|
||||
static std::weak_ptr<uint32_t> gDefaultFlagBuffer;
|
||||
static uint32_t* gDefaultFlagBuffer = nullptr;
|
||||
static std::weak_ptr<void> gDefaultFlagBufferWeak;
|
||||
static size_t gDefaultFlagCount = 128;
|
||||
|
||||
std::pair<std::shared_ptr<void>, size_t> getDefaultFlagBuffer() {
|
||||
std::shared_ptr<uint32_t> flagBuffer = gDefaultFlagBuffer.lock();
|
||||
if (!flagBuffer) {
|
||||
flagBuffer = mscclpp::detail::gpuCallocShared<uint32_t>(gDefaultFlagCount);
|
||||
std::vector<uint32_t> initFlags(gDefaultFlagCount, 1);
|
||||
mscclpp::gpuMemcpy(flagBuffer.get(), initFlags.data(), gDefaultFlagCount, cudaMemcpyHostToDevice);
|
||||
gDefaultFlagBuffer = flagBuffer;
|
||||
std::pair<std::shared_ptr<void>, size_t> getFlagBuffer() {
|
||||
auto ptr = gDefaultFlagBufferWeak.lock();
|
||||
if (!ptr) {
|
||||
if (!gDefaultFlagBuffer) {
|
||||
// Intentionally never freed — CUDA driver reclaims GPU memory at process exit.
|
||||
gDefaultFlagBuffer = static_cast<uint32_t*>(mscclpp::detail::gpuCalloc(gDefaultFlagCount * sizeof(uint32_t)));
|
||||
std::vector<uint32_t> initFlags(gDefaultFlagCount, 1);
|
||||
mscclpp::gpuMemcpy(gDefaultFlagBuffer, initFlags.data(), gDefaultFlagCount, cudaMemcpyHostToDevice);
|
||||
}
|
||||
ptr = std::shared_ptr<void>(gDefaultFlagBuffer, [](void*) {});
|
||||
gDefaultFlagBufferWeak = ptr;
|
||||
}
|
||||
return {flagBuffer, gDefaultFlagCount * sizeof(uint32_t)};
|
||||
return {ptr, gDefaultFlagCount * sizeof(uint32_t)};
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
Reference in New Issue
Block a user