diff --git a/include/mscclpp/algorithm.hpp b/include/mscclpp/algorithm.hpp index 6cc05ad4..07149cab 100644 --- a/include/mscclpp/algorithm.hpp +++ b/include/mscclpp/algorithm.hpp @@ -366,7 +366,7 @@ class AlgorithmCollection { /// Get a default GPU flag buffer (allocated once and reused). /// @return A pair of (shared_ptr to the flag buffer, size in bytes). -std::pair, size_t> getDefaultFlagBuffer(); +std::pair, size_t> getFlagBuffer(); } // namespace mscclpp diff --git a/python/csrc/algorithm.cpp b/python/csrc/algorithm.cpp index c8365566..f0d8980d 100644 --- a/python/csrc/algorithm.cpp +++ b/python/csrc/algorithm.cpp @@ -116,10 +116,15 @@ void register_algorithm(nb::module_& m) { .def("buffer_mode", &CollectiveRequest::bufferMode); m.def( - "cpp_get_default_flag_buffer", + "cpp_get_flag_buffer", []() { - auto [buffer, size] = getDefaultFlagBuffer(); - return std::make_pair(reinterpret_cast(buffer.get()), size); + auto [buffer, size] = getFlagBuffer(); + uintptr_t ptr = reinterpret_cast(buffer.get()); + // Transfer shared_ptr ownership into a capsule so Python's GC manages the lifetime. + auto prevent = std::make_unique>(std::move(buffer)); + nb::capsule owner(prevent.get(), [](void* p) noexcept { delete static_cast*>(p); }); + prevent.release(); // capsule now owns the pointer + return nb::make_tuple(ptr, size, owner); }, - "Get the default flag buffer. Returns a tuple of (buffer_ptr, buffer_size)."); + "Get the default flag buffer. Returns a tuple of (buffer_ptr, buffer_size, owner)."); } \ No newline at end of file diff --git a/python/mscclpp/_core/algorithm.py b/python/mscclpp/_core/algorithm.py index c712bf88..9b870582 100644 --- a/python/mscclpp/_core/algorithm.py +++ b/python/mscclpp/_core/algorithm.py @@ -19,7 +19,7 @@ from mscclpp._mscclpp import ( CppReduceOp, CppAlgorithmBuilder, CppAlgorithmCollection, - cpp_get_default_flag_buffer, + cpp_get_flag_buffer, ) __all__ = ["Algorithm", "AlgorithmBuilder", "AlgorithmCollection"] @@ -241,15 +241,22 @@ class AlgorithmCollection: self._algorithms.append(algorithm) -def get_default_flag_buffer() -> cp.ndarray: +_flag_buffer_cache = None + + +def get_flag_buffer() -> cp.ndarray: """Get the default flag buffer for algorithm selection. This buffer is used internally by default algorithms to store selection flags. It is allocated as a shared GPU buffer and can be accessed from Python. + The result is cached so all callers share the same buffer. Returns: A CuPy array representing the flag buffer on the GPU. """ - buffer_ptr, buffer_size = cpp_get_default_flag_buffer() - memptr = cp.cuda.MemoryPointer(cp.cuda.UnownedMemory(buffer_ptr, buffer_size, None), 0) - return cp.ndarray((buffer_size // 4,), dtype=cp.uint32, memptr=memptr) + global _flag_buffer_cache + if _flag_buffer_cache is None: + buffer_ptr, buffer_size, owner = cpp_get_flag_buffer() + memptr = cp.cuda.MemoryPointer(cp.cuda.UnownedMemory(buffer_ptr, buffer_size, owner), 0) + _flag_buffer_cache = cp.ndarray((buffer_size // 4,), dtype=cp.uint32, memptr=memptr) + return _flag_buffer_cache diff --git a/python/mscclpp/ext/algorithm_collection_builder.py b/python/mscclpp/ext/algorithm_collection_builder.py index 80c68909..ddfb929f 100644 --- a/python/mscclpp/ext/algorithm_collection_builder.py +++ b/python/mscclpp/ext/algorithm_collection_builder.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import Union -from mscclpp._core.algorithm import Algorithm, AlgorithmBuilder, AlgorithmCollection, get_default_flag_buffer +from mscclpp._core.algorithm import Algorithm, AlgorithmBuilder, AlgorithmCollection, get_flag_buffer import atexit from mscclpp._mscclpp import CppAlgorithmCollectionBuilder @@ -58,7 +58,7 @@ class AlgorithmCollectionBuilder: rank: int, ) -> AlgorithmCollection: if self._flag_buffer is None: - self._flag_buffer = get_default_flag_buffer() + self._flag_buffer = get_flag_buffer() native_collection = self._builder.build_default_algorithms( int(scratch_buffer), scratch_buffer_size, self._flag_buffer.data.ptr, self._flag_buffer.nbytes, rank ) diff --git a/src/core/algorithm.cc b/src/core/algorithm.cc index 98ac5520..683d4ddd 100644 --- a/src/core/algorithm.cc +++ b/src/core/algorithm.cc @@ -199,18 +199,23 @@ std::shared_ptr DslAlgorithm::build() { return shared_from_this(); } // TODO: implement this void DslAlgorithm::reset() {} -static std::weak_ptr gDefaultFlagBuffer; +static uint32_t* gDefaultFlagBuffer = nullptr; +static std::weak_ptr gDefaultFlagBufferWeak; static size_t gDefaultFlagCount = 128; -std::pair, size_t> getDefaultFlagBuffer() { - std::shared_ptr flagBuffer = gDefaultFlagBuffer.lock(); - if (!flagBuffer) { - flagBuffer = mscclpp::detail::gpuCallocShared(gDefaultFlagCount); - std::vector initFlags(gDefaultFlagCount, 1); - mscclpp::gpuMemcpy(flagBuffer.get(), initFlags.data(), gDefaultFlagCount, cudaMemcpyHostToDevice); - gDefaultFlagBuffer = flagBuffer; +std::pair, size_t> getFlagBuffer() { + auto ptr = gDefaultFlagBufferWeak.lock(); + if (!ptr) { + if (!gDefaultFlagBuffer) { + // Intentionally never freed — CUDA driver reclaims GPU memory at process exit. + gDefaultFlagBuffer = static_cast(mscclpp::detail::gpuCalloc(gDefaultFlagCount * sizeof(uint32_t))); + std::vector initFlags(gDefaultFlagCount, 1); + mscclpp::gpuMemcpy(gDefaultFlagBuffer, initFlags.data(), gDefaultFlagCount, cudaMemcpyHostToDevice); + } + ptr = std::shared_ptr(gDefaultFlagBuffer, [](void*) {}); + gDefaultFlagBufferWeak = ptr; } - return {flagBuffer, gDefaultFlagCount * sizeof(uint32_t)}; + return {ptr, gDefaultFlagCount * sizeof(uint32_t)}; } } // namespace mscclpp \ No newline at end of file diff --git a/src/ext/nccl/nccl.cc b/src/ext/nccl/nccl.cc index bfde4786..afeb5bdb 100644 --- a/src/ext/nccl/nccl.cc +++ b/src/ext/nccl/nccl.cc @@ -294,7 +294,7 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI commPtr->scratchBuffer_ = mscclpp::GpuBuffer(commPtr->scratchBufferSize_).memory(); commPtr->executor = std::make_shared(mscclppComm, commPtr->scratchBuffer_); - auto [buffer, size] = mscclpp::getDefaultFlagBuffer(); + auto [buffer, size] = mscclpp::getFlagBuffer(); commPtr->flagBuffer_ = buffer; commPtr->flagBufferSize_ = size;