mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 17:26:04 +00:00
Refactor algo selection logic and introduce symmetric_memory env (#741)
This PR refactors the algorithm selection logic in MSCCL++ and introduces support for symmetric memory configuration through environment variables. 1. Algorithm Selection Refactoring Use separate class for algo selection. Could introduce more complex logic for algo selection based on message size, arch, if cuda graph is enabled and memory allocation method 2. Symmetric Memory Support Introduced symmetricMemory parameter in algorithm context key generation. Remove disableChannelCache env as is ambiguous 3. Add new args for build_default_algorithms Add flag_buffer, and flag_buffer_size args to build default algorithm. Then we could use unified flag buffer for different algorithms, avoid application hanging when switch algo for different message size. --------- Co-authored-by: chhwang <8018170+chhwang@users.noreply.github.com> Co-authored-by: Qinghua Zhou <qinghuazhou@microsoft.com> Co-authored-by: Caio Rocha <caiorocha@microsoft.com>
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Union
|
||||
from mscclpp._core.algorithm import Algorithm, AlgorithmBuilder, AlgorithmCollection
|
||||
from mscclpp._core.algorithm import Algorithm, AlgorithmBuilder, AlgorithmCollection, get_default_flag_buffer
|
||||
import atexit
|
||||
|
||||
from mscclpp._mscclpp import CppAlgorithmCollectionBuilder
|
||||
@@ -29,6 +29,7 @@ class AlgorithmCollectionBuilder:
|
||||
if not hasattr(self, "_initialized"):
|
||||
self._builder = CppAlgorithmCollectionBuilder.get_instance()
|
||||
self._initialized = True
|
||||
self._flag_buffer = None
|
||||
|
||||
def add_algorithm_builder(self, algorithm_builder: Union[AlgorithmBuilder, Algorithm]):
|
||||
if isinstance(algorithm_builder, AlgorithmBuilder):
|
||||
@@ -50,8 +51,17 @@ class AlgorithmCollectionBuilder:
|
||||
collection = self._builder.build()
|
||||
return AlgorithmCollection(collection)
|
||||
|
||||
def build_default_algorithms(self, scratch_buffer: int, scratch_buffer_size: int, rank: int) -> AlgorithmCollection:
|
||||
native_collection = self._builder.build_default_algorithms(int(scratch_buffer), scratch_buffer_size, rank)
|
||||
def build_default_algorithms(
|
||||
self,
|
||||
scratch_buffer: int,
|
||||
scratch_buffer_size: int,
|
||||
rank: int,
|
||||
) -> AlgorithmCollection:
|
||||
if self._flag_buffer is None:
|
||||
self._flag_buffer = get_default_flag_buffer()
|
||||
native_collection = self._builder.build_default_algorithms(
|
||||
int(scratch_buffer), scratch_buffer_size, self._flag_buffer.data.ptr, self._flag_buffer.nbytes, rank
|
||||
)
|
||||
return AlgorithmCollection(native_collection)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user