Files
mscclpp/python/mscclpp/__main__.py
Caio Rocha 1b9f335ddd wip
2026-05-05 21:19:23 +00:00

141 lines
4.2 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import shutil
import argparse
from pathlib import Path
from mscclpp import default_algos as def_algo
from mscclpp.language.collectives import *
from mscclpp.language.utils import AlgoSpec
def _allreduce_entry(
    filename,
    name,
    function,
    world_size,
    min_message_size,
    max_message_size,
    thread_block_group_size,
):
    """Build one entry of ``default_algo_configs`` for an AllReduce plan.

    Every default entry uses the same AlgoSpec settings (8 ranks per node,
    in-place, "LL" protocol, 1024 threads per block, ...); only the values
    passed here vary between entries.

    Args:
        filename: File name the generated plan is written to.
        name: Name given to the AlgoSpec.
        function: Callable from ``default_algos`` that produces the program.
        world_size: Total rank count; also passed as the first AllReduce argument.
        min_message_size: Lower bound of the message-size range, in bytes.
        max_message_size: Upper bound of the message-size range, in bytes.
        thread_block_group_size: Forwarded to ``function`` as a keyword argument.

    Returns:
        A dict with the keys ``"filename"``, ``"function"``, ``"spec"`` and
        ``"additional_kwargs"``.
    """
    algo_spec = AlgoSpec(
        name=name,
        collective=AllReduce(world_size, 1, True),
        nranks_per_node=8,
        world_size=world_size,
        in_place=True,
        instances=1,
        protocol="LL",
        auto_sync=False,
        num_threads_per_block=1024,
        reuse_resources=True,
        use_double_scratch_buffer=True,
        min_message_size=min_message_size,
        max_message_size=max_message_size,
        tags={"default": 1},
    )
    return {
        "filename": filename,
        "function": function,
        "spec": algo_spec,
        "additional_kwargs": {"thread_block_group_size": thread_block_group_size},
    }


# Default AllReduce plans, one entry per (node count, message-size range).
# With 8 ranks per node, world sizes 16 / 32 / 64 correspond to 2 / 4 / 8 nodes.
default_algo_configs = [
    # 2 nodes, 1 KiB - 64 KiB
    _allreduce_entry(
        "allreduce_2nodes_1K_64K.json",
        "allreduce_2nodes_1K_64K",
        def_algo.allreduce_2nodes,
        world_size=16,
        min_message_size=1 << 10,
        max_message_size=64 << 10,
        thread_block_group_size=1,
    ),
    # 2 nodes, 128 KiB - 2 MiB
    _allreduce_entry(
        "allreduce_2nodes_128K_2M.json",
        "allreduce_2nodes_128K_2M",
        def_algo.allreduce_2nodes,
        world_size=16,
        min_message_size=128 << 10,
        max_message_size=2 << 20,
        thread_block_group_size=4,
    ),
    # 4 nodes, 1 KiB - 8 MiB
    _allreduce_entry(
        "allreduce_4nodes_1K_8M.json",
        "allreduce_4nodes_1K_8M",
        def_algo.allreduce_multi_nodes,
        world_size=32,
        min_message_size=1 << 10,
        max_message_size=8 << 20,
        thread_block_group_size=8,
    ),
    # 8 nodes, 1 KiB - 8 MiB
    _allreduce_entry(
        "allreduce_8nodes_1K_8M.json",
        "allreduce_8nodes_1K_8M",
        def_algo.allreduce_multi_nodes,
        world_size=64,
        min_message_size=1 << 10,
        max_message_size=8 << 20,
        thread_block_group_size=1,
    ),
]
def create_default_plans():
    """Regenerate the default plan files in the cache directory.

    The output directory is ``$MSCCLPP_CACHE_DIR/default``; when the
    environment variable is unset it falls back to
    ``~/.cache/mscclpp/default``. An existing directory at that path is
    removed first, so afterwards it holds only plans written by this call.

    For each entry of ``default_algo_configs`` the entry's ``"function"`` is
    called with its ``"spec"`` (plus ``"additional_kwargs"`` when present),
    and the result's ``to_json()`` text is written to ``"filename"`` inside
    the output directory.

    A failure while building, serializing or writing one plan is printed and
    that entry is skipped; the remaining entries are still processed.
    """
    cache_root = os.environ.get("MSCCLPP_CACHE_DIR", Path.home() / ".cache/mscclpp")
    plan_dir = Path(cache_root) / "default"

    # Start from an empty directory so plans from an earlier install cannot linger.
    if plan_dir.exists():
        shutil.rmtree(plan_dir)
    plan_dir.mkdir(parents=True)

    for config in default_algo_configs:
        filename = config["filename"]
        func = config["function"]
        spec = config["spec"]
        # A missing, None or empty "additional_kwargs" all mean "no extra arguments".
        additional_kwargs = config.get("additional_kwargs") or {}
        plan_file = plan_dir / filename
        try:
            prog = func(spec, **additional_kwargs)
            # Serialize before opening the file: opening in "w" mode creates
            # the file immediately, so a to_json() failure after the open
            # would leave an empty plan file behind.
            plan_json = prog.to_json()
            with open(plan_file, "w", encoding="utf-8") as f:
                f.write(plan_json)
        except Exception as e:
            # Best-effort install: report the failing plan and keep going.
            print(f"Error creating plan for {spec.name}: {e}")
            continue
def main():
    """Command-line entry point.

    Recognizes a single optional flag, ``--install``. When the flag is given,
    the default plans are regenerated via ``create_default_plans``; without
    it the function returns without doing anything.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--install", action="store_true", help="flag to install default plans")
    options = cli.parse_args()

    if not options.install:
        return
    create_default_plans()


# Run the CLI only when this file is executed, not when it is imported.
if __name__ == "__main__":
    main()