mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 09:17:06 +00:00
This pull request updates the handling of the default flag buffer in the C++ and Python bindings to ensure proper memory management when interfacing with Python. It makes sure the buffer is not deallocated when ownership is transferred from C++ to Python.
69 lines
2.4 KiB
Python
69 lines
2.4 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
from __future__ import annotations
|
|
from typing import Union
|
|
from mscclpp._core.algorithm import Algorithm, AlgorithmBuilder, AlgorithmCollection, get_flag_buffer
|
|
import atexit
|
|
|
|
from mscclpp._mscclpp import CppAlgorithmCollectionBuilder
|
|
|
|
# Public API of this module: only the builder facade is exported by ``from ... import *``.
__all__ = ["AlgorithmCollectionBuilder"]
|
|
|
|
|
|
class AlgorithmCollectionBuilder:
    """Process-wide Python facade over the native ``CppAlgorithmCollectionBuilder``.

    Every ``AlgorithmCollectionBuilder()`` call yields the same Python object
    (see ``__new__``); the native builder is fetched only on the first
    ``__init__``.
    """

    # The single shared Python instance; None before first construction or after reset().
    _instance = None

    def __new__(cls):
        """Return the shared instance, allocating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def reset(cls):
        """Reset the native builder singleton and forget the cached Python instance.

        Does nothing when no instance currently exists.
        """
        if cls._instance is None:
            return
        CppAlgorithmCollectionBuilder.reset()
        cls._instance = None

    def __init__(self):
        # Python calls __init__ on every AlgorithmCollectionBuilder() even though
        # __new__ hands back the cached object, so setup must run only once.
        if hasattr(self, "_initialized"):
            return
        self._builder = CppAlgorithmCollectionBuilder.get_instance()
        self._initialized = True
        # Filled lazily by build_default_algorithms(). It is held on the instance so the
        # buffer object stays referenced from Python after its raw pointer
        # (``.data.ptr``) has been passed to the native builder.
        self._flag_buffer = None

    def add_algorithm_builder(self, algorithm_builder: Union[AlgorithmBuilder, Algorithm]):
        """Register an algorithm with the native builder.

        Args:
            algorithm_builder: either an ``AlgorithmBuilder`` (its wrapped native
                builder is registered) or an ``Algorithm`` whose
                ``is_dsl_algorithm()`` is true (registered as a DSL algorithm).

        Raises:
            ValueError: for any other argument, including a non-DSL ``Algorithm``.
        """
        if isinstance(algorithm_builder, AlgorithmBuilder):
            self._builder.add_algorithm_builder(algorithm_builder._algorithm_builder)
        elif isinstance(algorithm_builder, Algorithm) and algorithm_builder.is_dsl_algorithm():
            self._builder.add_dsl_algorithm_builder(algorithm_builder._algorithm)
        else:
            raise ValueError("The 'algorithm_builder' argument must be an instance of AlgorithmBuilder or DSL Algorithm.")

    def set_algorithm_selector(self, selector):
        """Forward ``selector`` unchanged to the native builder's primary selector hook."""
        self._builder.set_algorithm_selector(selector)

    def set_fallback_algorithm_selector(self, selector):
        """Forward ``selector`` unchanged to the native builder's fallback selector hook."""
        self._builder.set_fallback_algorithm_selector(selector)

    def build(self) -> AlgorithmCollection:
        """Build the registered algorithms and wrap the native result."""
        return AlgorithmCollection(self._builder.build())

    def build_default_algorithms(
        self,
        scratch_buffer: int,
        scratch_buffer_size: int,
        rank: int,
    ) -> AlgorithmCollection:
        """Build the native default algorithm set and wrap it in an ``AlgorithmCollection``.

        Args:
            scratch_buffer: address of the scratch buffer; coerced with ``int()``
                before being passed on (presumably a device pointer — confirm).
            scratch_buffer_size: size of the scratch buffer, passed through as-is.
            rank: rank of the calling process, passed through as-is.

        The flag buffer comes from ``get_flag_buffer()`` on the first call and is
        reused on later calls. NOTE(review): it is accessed via ``.data.ptr`` and
        ``.nbytes``, which looks like a CuPy-style array — confirm.
        """
        flag_buffer = self._flag_buffer
        if flag_buffer is None:
            flag_buffer = self._flag_buffer = get_flag_buffer()
        native = self._builder.build_default_algorithms(
            int(scratch_buffer),
            scratch_buffer_size,
            flag_buffer.data.ptr,
            flag_buffer.nbytes,
            rank,
        )
        return AlgorithmCollection(native)
|
|
|
|
|
|
# Reset the builder (including the native CppAlgorithmCollectionBuilder singleton) when the
# interpreter exits; reset() does nothing if no instance currently exists.
atexit.register(AlgorithmCollectionBuilder.reset)
|