diff --git a/python/mscclpp/__main__.py b/python/mscclpp/__main__.py index 6a6f5f28..cab500ce 100644 --- a/python/mscclpp/__main__.py +++ b/python/mscclpp/__main__.py @@ -53,6 +53,26 @@ default_algo_configs = [ ), "additional_kwargs": {"thread_block_group_size": 4}, }, + { + "filename": "allgather_mrc.json", + "function": def_algo.allgather_mrc, + "spec": AlgoSpec( + name="allgather_mrc", + collective=AllGather(2, 1, True), + nranks_per_node=1, + world_size=2, + in_place=True, + instances=1, + protocol="Simple", + auto_sync=True, + num_threads_per_block=1024, + reuse_resources=False, + use_double_scratch_buffer=False, + min_message_size=1 << 10, + max_message_size=8 << 30, + tags={"default": 1}, + ) + } ] diff --git a/python/mscclpp/default_algos/__init__.py b/python/mscclpp/default_algos/__init__.py index a5cfa882..10cf339e 100644 --- a/python/mscclpp/default_algos/__init__.py +++ b/python/mscclpp/default_algos/__init__.py @@ -2,5 +2,6 @@ # Licensed under the MIT License. from mscclpp.default_algos.allreduce_2nodes import allreduce_2nodes +from mscclpp.default_algos.allgather_mrc import allgather_mrc -__all__ = ["allreduce_2nodes"] +__all__ = ["allreduce_2nodes", "allgather_mrc"] diff --git a/python/mscclpp/language/tests/multi_node/allgather_mrc.py b/python/mscclpp/default_algos/allgather_mrc.py similarity index 64% rename from python/mscclpp/language/tests/multi_node/allgather_mrc.py rename to python/mscclpp/default_algos/allgather_mrc.py index 88d4527d..1950f105 100644 --- a/python/mscclpp/language/tests/multi_node/allgather_mrc.py +++ b/python/mscclpp/default_algos/allgather_mrc.py @@ -8,23 +8,13 @@ from mscclpp.language.general import * from mscclpp.language.program import * from mscclpp.language.collectives import * from mscclpp.language.loop import LoopIterationContext +from mscclpp.language.utils import AlgoSpec -def allgather_hierarchical(name, gpus, num_threads_per_block, min_message_size, max_message_size): - size = gpus - chunksperloop = 1 - collective = AllGather(size, chunksperloop, True) - with CollectiveProgram( - name, - collective, - size, - protocol="Simple", - num_threads_per_block=num_threads_per_block, - instances=1, - use_double_scratch_buffer=False, - min_message_size=min_message_size, - max_message_size=max_message_size, - ): +def allgather_mrc(spec: AlgoSpec) -> CollectiveProgram: + size = spec.world_size + + with CollectiveProgram.from_spec(spec) as prog: # Port channels for inter-node communication port_channels = {} for n in range(size): @@ -65,18 +55,4 @@ def allgather_hierarchical(name, gpus, num_threads_per_block, min_message_size, recv_src_chunk = Rank(src_rank).get_output_buffer()[recv_offset:recv_offset + 1] ch_from_prev.wait(tb=0) - print(JSON()) - - -parser = argparse.ArgumentParser() -parser.add_argument("--name", type=str, help="name of the program") -parser.add_argument("--num_gpus", type=int, help="number of gpus") -parser.add_argument("--num_threads_per_block", type=int, default=1024, help="number of threads per block") -parser.add_argument("--min_message_size", type=int, default=0, help="minimum message size") -parser.add_argument("--max_message_size", type=int, default=2**64 - 1, help="maximum message size") - -args = parser.parse_args() - -allgather_hierarchical( - args.name, args.num_gpus, args.num_threads_per_block, args.min_message_size, args.max_message_size -) + return prog diff --git a/src/ext/collectives/algorithm_collection_builder.cc b/src/ext/collectives/algorithm_collection_builder.cc index 2a7e6e91..1ea4adb0 100644 --- a/src/ext/collectives/algorithm_collection_builder.cc +++ b/src/ext/collectives/algorithm_collection_builder.cc @@ -113,7 +113,8 @@ AlgorithmCollection AlgorithmCollectionBuilder::buildDefaultDslAlgorithms(int ra }; static const std::vector defaultAlgoConfigs = { {"allreduce_2nodes_1K_64K.json", "allreduce", 8, 16, {{"default", 1}}}, - {"allreduce_2nodes_64K_2M.json", "allreduce", 8, 16, {{"default", 1}}}}; + {"allreduce_2nodes_64K_2M.json", "allreduce", 8, 16, {{"default", 1}}}, + {"allgather_mrc.json", "allgather", 1, 2, {{"default", 1}}}}; AlgorithmCollection collection; static auto generateFileId = [](const std::string& input) {