Files
mscclpp/python/mscclpp/ext/algorithm_collection_builder.py
Binyang Li bd68319e3e Refactor algo selection logic and introduce symmetric_memory env (#741)
This PR refactors the algorithm selection logic in MSCCL++ and
introduces support for symmetric memory configuration through
environment variables.


1. Algorithm Selection Refactoring
Use a separate class for algorithm selection. This makes it possible to
introduce more complex selection logic based on message size, architecture,
whether CUDA graphs are enabled, and the memory allocation method.

2. Symmetric Memory Support
Introduce a symmetricMemory parameter in algorithm context key
generation. Remove the disableChannelCache env variable, as it is ambiguous.

3. Add new args for build_default_algorithms 
Add flag_buffer and flag_buffer_size args to build_default_algorithms.
This lets different algorithms share a unified flag buffer, which avoids
application hangs when switching algorithms for different message sizes.

---------

Co-authored-by: chhwang <8018170+chhwang@users.noreply.github.com>
Co-authored-by: Qinghua Zhou <qinghuazhou@microsoft.com>
Co-authored-by: Caio Rocha <caiorocha@microsoft.com>
2026-02-12 19:06:18 -08:00

69 lines
2.4 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from typing import Union
from mscclpp._core.algorithm import Algorithm, AlgorithmBuilder, AlgorithmCollection, get_default_flag_buffer
import atexit
from mscclpp._mscclpp import CppAlgorithmCollectionBuilder
# Public API of this module: only the builder front-end is exported.
__all__ = ["AlgorithmCollectionBuilder"]
class AlgorithmCollectionBuilder:
    """Singleton Python front-end for the native ``CppAlgorithmCollectionBuilder``.

    Every ``AlgorithmCollectionBuilder()`` call returns the same cached object
    until :meth:`reset` is invoked. Registration and build requests are forwarded
    to the native builder obtained from ``CppAlgorithmCollectionBuilder.get_instance()``.
    """

    # Cached singleton; ``None`` until first construction and again after ``reset()``.
    _instance = None

    def __new__(cls):
        """Return the cached instance, allocating it on first use."""
        existing = cls._instance
        if existing is None:
            existing = super().__new__(cls)
            cls._instance = existing
        return existing

    @classmethod
    def reset(cls):
        """Reset the native builder and forget the cached Python singleton.

        This is a no-op when no instance currently exists. The next construction
        then creates a fresh object and re-fetches the native builder.
        """
        if cls._instance is None:
            return
        CppAlgorithmCollectionBuilder.reset()
        cls._instance = None

    def __init__(self):
        """Bind the native builder the first time the singleton is initialised.

        ``__new__`` hands back the cached object, so ``__init__`` runs on every
        ``AlgorithmCollectionBuilder()`` call. The ``_initialized`` marker makes
        later calls leave existing state untouched.
        """
        if hasattr(self, "_initialized"):
            return
        self._builder = CppAlgorithmCollectionBuilder.get_instance()
        self._initialized = True
        # Populated lazily by build_default_algorithms() and reused afterwards.
        self._flag_buffer = None

    def add_algorithm_builder(self, algorithm_builder: Union[AlgorithmBuilder, Algorithm]):
        """Register an algorithm source with the native builder.

        Args:
            algorithm_builder: Either an ``AlgorithmBuilder``, whose wrapped
                ``_algorithm_builder`` is registered, or an ``Algorithm`` for which
                ``is_dsl_algorithm()`` is true, whose wrapped ``_algorithm`` is
                registered as a DSL algorithm.

        Raises:
            ValueError: If the argument is neither an ``AlgorithmBuilder`` nor a DSL
                ``Algorithm``.
        """
        if isinstance(algorithm_builder, AlgorithmBuilder):
            self._builder.add_algorithm_builder(algorithm_builder._algorithm_builder)
        elif isinstance(algorithm_builder, Algorithm) and algorithm_builder.is_dsl_algorithm():
            self._builder.add_dsl_algorithm_builder(algorithm_builder._algorithm)
        else:
            raise ValueError("The 'algorithm_builder' argument must be an instance of AlgorithmBuilder or DSL Algorithm.")

    def set_algorithm_selector(self, selector):
        """Forward ``selector`` to the native builder's ``set_algorithm_selector``.

        NOTE(review): the expected selector signature is defined by the native
        binding and is not visible here; confirm it before use.
        """
        self._builder.set_algorithm_selector(selector)

    def set_fallback_algorithm_selector(self, selector):
        """Forward ``selector`` to the native builder's ``set_fallback_algorithm_selector``.

        NOTE(review): presumably consulted when the primary selector yields no
        choice; confirm against the native implementation.
        """
        self._builder.set_fallback_algorithm_selector(selector)

    def build(self) -> AlgorithmCollection:
        """Build the registered algorithms and wrap the result in an ``AlgorithmCollection``."""
        return AlgorithmCollection(self._builder.build())

    def build_default_algorithms(
        self,
        scratch_buffer: int,
        scratch_buffer_size: int,
        rank: int,
    ) -> AlgorithmCollection:
        """Build the native default algorithm set and wrap it in an ``AlgorithmCollection``.

        On the first call, a flag buffer is obtained from ``get_default_flag_buffer()``
        and cached on the instance. Every later call reuses that same buffer, so all
        default algorithms built through this singleton share it. The buffer's
        ``.data.ptr`` and ``.nbytes`` are passed to the native builder as the flag
        buffer address and size.

        Args:
            scratch_buffer: Scratch buffer handle. It is converted with ``int()``
                before being passed on (presumably a device address; confirm).
            scratch_buffer_size: Scratch buffer size, passed through unchanged.
            rank: Rank of the calling process, passed through unchanged.

        Returns:
            An ``AlgorithmCollection`` wrapping the native collection.
        """
        flag_buffer = self._flag_buffer
        if flag_buffer is None:
            flag_buffer = get_default_flag_buffer()
            self._flag_buffer = flag_buffer
        native_collection = self._builder.build_default_algorithms(
            int(scratch_buffer),
            scratch_buffer_size,
            flag_buffer.data.ptr,
            flag_buffer.nbytes,
            rank,
        )
        return AlgorithmCollection(native_collection)
# At interpreter exit, call AlgorithmCollectionBuilder.reset(). This also resets
# the native CppAlgorithmCollectionBuilder if a Python singleton exists.
# NOTE(review): presumably this tears down native state (and any Python
# selectors it holds) before module teardown; confirm the motivation.
atexit.register(AlgorithmCollectionBuilder.reset)