diff --git a/python/mscclpp/__main__.py b/python/mscclpp/__main__.py index cab500ce..41f6fdc5 100644 --- a/python/mscclpp/__main__.py +++ b/python/mscclpp/__main__.py @@ -54,10 +54,10 @@ default_algo_configs = [ "additional_kwargs": {"thread_block_group_size": 4}, }, { - "filename": "allgather_mrc.json", + "filename": "allgather_mrc_2_nodes.json", "function": def_algo.allgather_mrc, "spec": AlgoSpec( - name="allgather_mrc", + name="allgather_mrc_2_nodes", collective=AllGather(2, 1, True), nranks_per_node=1, world_size=2, @@ -72,6 +72,26 @@ default_algo_configs = [ max_message_size=8 << 30, tags={"default": 1}, ) + }, + { + "filename": "allgather_mrc_4_nodes.json", + "function": def_algo.allgather_mrc, + "spec": AlgoSpec( + name="allgather_mrc_4_nodes", + collective=AllGather(4, 1, True), + nranks_per_node=1, + world_size=4, + in_place=True, + instances=1, + protocol="Simple", + auto_sync=True, + num_threads_per_block=1024, + reuse_resources=False, + use_double_scratch_buffer=False, + min_message_size=1 << 10, + max_message_size=8 << 30, + tags={"default": 1}, + ) } ] diff --git a/python/mscclpp/default_algos/allgather_mrc.py b/python/mscclpp/default_algos/allgather_mrc.py index 1950f105..9368090f 100644 --- a/python/mscclpp/default_algos/allgather_mrc.py +++ b/python/mscclpp/default_algos/allgather_mrc.py @@ -56,3 +56,35 @@ def allgather_mrc(spec: AlgoSpec) -> CollectiveProgram: ch_from_prev.wait(tb=0) return prog + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--name", type=str, help="name of the program") + parser.add_argument("--num_gpus", type=int, help="total number of gpus") + parser.add_argument("--gpus_per_node", type=int, help="number of gpus per node") + parser.add_argument("--num_threads_per_block", type=int, default=1024, help="number of threads per block") + parser.add_argument("--min_message_size", type=int, default=0, help="minimum message size") + parser.add_argument("--max_message_size", type=int, default=2**64 - 1, help="maximum message size") + + args = parser.parse_args() + + spec = AlgoSpec( + name=args.name, + collective=AllGather(args.num_gpus, 1, True), + nranks_per_node=args.gpus_per_node, + world_size=args.num_gpus, + in_place=True, + instances=1, + protocol="Simple", + auto_sync=True, + num_threads_per_block=args.num_threads_per_block, + reuse_resources=False, + use_double_scratch_buffer=False, + min_message_size=args.min_message_size, + max_message_size=args.max_message_size, + ) + + prog = allgather_mrc(spec) + print(prog.to_json()) \ No newline at end of file diff --git a/src/ext/collectives/algorithm_collection_builder.cc b/src/ext/collectives/algorithm_collection_builder.cc index 1ea4adb0..be0b919d 100644 --- a/src/ext/collectives/algorithm_collection_builder.cc +++ b/src/ext/collectives/algorithm_collection_builder.cc @@ -114,7 +114,8 @@ AlgorithmCollection AlgorithmCollectionBuilder::buildDefaultDslAlgorithms(int ra static const std::vector defaultAlgoConfigs = { {"allreduce_2nodes_1K_64K.json", "allreduce", 8, 16, {{"default", 1}}}, {"allreduce_2nodes_64K_2M.json", "allreduce", 8, 16, {{"default", 1}}}, - {"allgather_mrc.json", "allgather", 1, 2, {{"default", 1}}}}; + {"allgather_mrc_2_nodes.json", "allgather", 1, 2, {{"default", 1}}}, + {"allgather_mrc_4_nodes.json", "allgather", 1, 4, {{"default", 1}}}}; AlgorithmCollection collection; static auto generateFileId = [](const std::string& input) { diff --git a/src/ext/nccl/algorithm_selector.cc b/src/ext/nccl/algorithm_selector.cc index dc3b84fe..0b9592d7 100644 --- a/src/ext/nccl/algorithm_selector.cc +++ b/src/ext/nccl/algorithm_selector.cc @@ -53,7 +53,6 @@ bool matchExecutionPlan(std::shared_ptr algo, const CollectiveRequ bool maxSizeMatch = effectiveSize <= algo->messageRange().second; bool result = worldSizeMatch && ranksPerNodeMatch && collectiveMatch && bufferModeMatch && minSizeMatch && maxSizeMatch; - return true; return result; }