mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 17:26:04 +00:00
Torch integration (#692)
Reorganize current native algorithm implementation and DSL algorithm implementation. Provide unified API for DSL algo and native algo and provide interface to tune the algo Provide interface for pytorch integration with native API and DSL --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> Co-authored-by: chhwang <8018170+chhwang@users.noreply.github.com>
This commit is contained in:
60
python/mscclpp/ext/algorithm_collection_builder.py
Normal file
60
python/mscclpp/ext/algorithm_collection_builder.py
Normal file
@@ -0,0 +1,60 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Union
|
||||
from mscclpp._core.algorithm import Algorithm, AlgorithmBuilder, AlgorithmCollection
|
||||
import atexit
|
||||
|
||||
from mscclpp._mscclpp import (
|
||||
AlgorithmCollectionBuilder as _AlgorithmCollectionBuilder,
|
||||
)
|
||||
|
||||
__all__ = ["AlgorithmCollectionBuilder"]
|
||||
|
||||
|
||||
class AlgorithmCollectionBuilder:
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(AlgorithmCollectionBuilder, cls).__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def reset(cls):
|
||||
if cls._instance is not None:
|
||||
_AlgorithmCollectionBuilder.reset()
|
||||
cls._instance = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(self, "_initialized"):
|
||||
self._builder = _AlgorithmCollectionBuilder.get_instance()
|
||||
self._initialized = True
|
||||
|
||||
def add_algorithm_builder(self, algorithm_builder: Union[AlgorithmBuilder, Algorithm]):
|
||||
if isinstance(algorithm_builder, AlgorithmBuilder):
|
||||
self._builder.add_algorithm_builder(algorithm_builder._algorithm_builder)
|
||||
return
|
||||
if isinstance(algorithm_builder, Algorithm):
|
||||
if algorithm_builder.is_dsl_algorithm():
|
||||
self._builder.add_dsl_algorithm_builder(algorithm_builder._algorithm)
|
||||
return
|
||||
raise ValueError("The 'algorithm_builder' argument must be an instance of AlgorithmBuilder or DSL Algorithm.")
|
||||
|
||||
def set_algorithm_selector(self, selector):
|
||||
self._builder.set_algorithm_selector(selector)
|
||||
|
||||
def set_fallback_algorithm_selector(self, selector):
|
||||
self._builder.set_fallback_algorithm_selector(selector)
|
||||
|
||||
def build(self) -> AlgorithmCollection:
|
||||
collection = self._builder.build()
|
||||
return AlgorithmCollection(collection)
|
||||
|
||||
def build_default_algorithms(self, scratch_buffer: int, scratch_buffer_size: int, rank: int) -> AlgorithmCollection:
|
||||
native_collection = self._builder.build_default_algorithms(int(scratch_buffer), scratch_buffer_size, rank)
|
||||
return AlgorithmCollection(native_collection)
|
||||
|
||||
|
||||
atexit.register(AlgorithmCollectionBuilder.reset)
|
||||
Reference in New Issue
Block a user