address flagBuffer ownership issue (#749)

This pull request updates the handling of the default flag buffer in the
C++ and Python bindings to ensure proper memory management when
interfacing with Python.

Make sure the buffer will not be deallocated when transfer ownership
from cpp to python
This commit is contained in:
Binyang Li
2026-02-20 13:42:29 -08:00
committed by GitHub
parent 4701ae3a95
commit 39865c218b
6 changed files with 39 additions and 22 deletions

View File

@@ -19,7 +19,7 @@ from mscclpp._mscclpp import (
CppReduceOp,
CppAlgorithmBuilder,
CppAlgorithmCollection,
cpp_get_default_flag_buffer,
cpp_get_flag_buffer,
)
__all__ = ["Algorithm", "AlgorithmBuilder", "AlgorithmCollection"]
@@ -241,15 +241,22 @@ class AlgorithmCollection:
self._algorithms.append(algorithm)
def get_default_flag_buffer() -> cp.ndarray:
_flag_buffer_cache = None
def get_flag_buffer() -> cp.ndarray:
"""Get the default flag buffer for algorithm selection.
This buffer is used internally by default algorithms to store selection flags.
It is allocated as a shared GPU buffer and can be accessed from Python.
The result is cached so all callers share the same buffer.
Returns:
A CuPy array representing the flag buffer on the GPU.
"""
buffer_ptr, buffer_size = cpp_get_default_flag_buffer()
memptr = cp.cuda.MemoryPointer(cp.cuda.UnownedMemory(buffer_ptr, buffer_size, None), 0)
return cp.ndarray((buffer_size // 4,), dtype=cp.uint32, memptr=memptr)
global _flag_buffer_cache
if _flag_buffer_cache is None:
buffer_ptr, buffer_size, owner = cpp_get_flag_buffer()
memptr = cp.cuda.MemoryPointer(cp.cuda.UnownedMemory(buffer_ptr, buffer_size, owner), 0)
_flag_buffer_cache = cp.ndarray((buffer_size // 4,), dtype=cp.uint32, memptr=memptr)
return _flag_buffer_cache

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
from typing import Union
from mscclpp._core.algorithm import Algorithm, AlgorithmBuilder, AlgorithmCollection, get_default_flag_buffer
from mscclpp._core.algorithm import Algorithm, AlgorithmBuilder, AlgorithmCollection, get_flag_buffer
import atexit
from mscclpp._mscclpp import CppAlgorithmCollectionBuilder
@@ -58,7 +58,7 @@ class AlgorithmCollectionBuilder:
rank: int,
) -> AlgorithmCollection:
if self._flag_buffer is None:
self._flag_buffer = get_default_flag_buffer()
self._flag_buffer = get_flag_buffer()
native_collection = self._builder.build_default_algorithms(
int(scratch_buffer), scratch_buffer_size, self._flag_buffer.data.ptr, self._flag_buffer.nbytes, rank
)