diff --git a/examples/torch-integration/dsl_with_nccl_api.py b/examples/torch-integration/dsl_with_nccl_api.py
index 975d3749..5a4dd1c4 100644
--- a/examples/torch-integration/dsl_with_nccl_api.py
+++ b/examples/torch-integration/dsl_with_nccl_api.py
@@ -1,19 +1,20 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 
-# LD_PRELOAD=/build/lib/nccl/libmscclpp_nccl.so torchrun --nnodes=1 --nproc_per_node=8 dsl_with_nccl_api.py
+# LD_PRELOAD=/build/lib/libmscclpp_nccl.so torchrun --nnodes=1 --nproc_per_node=8 dsl_with_nccl_api.py
 
 import os
 from typing import Any, Dict
 import torch, torch.distributed as dist
-import mscclpp
+import mscclpp.ext
 from mscclpp.language.collectives import AllReduce
 from mscclpp.language.channel import SwitchChannel, MemoryChannel, BufferType, SyncType
 from mscclpp.language.program import CollectiveProgram
 from mscclpp.language.rank import Rank
+from mscclpp.language.utils import AlgoSpec
 
 
-def allreduce_nvls(spec: mscclpp.AlgoSpec) -> CollectiveProgram:
+def allreduce_nvls(spec: AlgoSpec) -> CollectiveProgram:
     gpu_size = spec.world_size
     with CollectiveProgram.from_spec(spec) as program:
         # Creating Channels
@@ -63,8 +64,8 @@ def allreduce_nvls(spec: mscclpp.AlgoSpec) -> CollectiveProgram:
     return program
 
 
-def setup_plan(algo_collection_builder: mscclpp.AlgorithmCollectionBuilder, rank: int, world_size: int):
-    spec = mscclpp.AlgoSpec(
+def setup_plan(algo_collection_builder: mscclpp.ext.AlgorithmCollectionBuilder, rank: int, world_size: int):
+    spec = AlgoSpec(
         name="allreduce_nvls",
         collective=AllReduce(8, 1, True),
         nranks_per_node=8,
@@ -94,10 +95,10 @@ def init_dist():
     rank = int(os.environ["RANK"])
     world = int(os.environ["WORLD_SIZE"])
     local = int(os.environ["LOCAL_RANK"])
-    algorithm_collection_builder = mscclpp.AlgorithmCollectionBuilder()
+    algorithm_collection_builder = mscclpp.ext.AlgorithmCollectionBuilder()
     setup_plan(algorithm_collection_builder, rank, world)
     algorithm_collection_builder.set_algorithm_selector(selector)
-    dist.init_process_group(backend="nccl", device_id=local)
+    dist.init_process_group(backend="nccl", device_id=torch.device("cuda", local))
     return rank, world, local
 
 
diff --git a/python/csrc/ext/algorithm_collection_builder_py.cpp b/python/csrc/ext/algorithm_collection_builder_py.cpp
index be7f944e..4a3563d9 100644
--- a/python/csrc/ext/algorithm_collection_builder_py.cpp
+++ b/python/csrc/ext/algorithm_collection_builder_py.cpp
@@ -4,6 +4,7 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 #include 
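
Note: the most load-bearing fix in the example file is the device_id argument to dist.init_process_group, which expects a torch.device rather than a bare integer. Below is a minimal sketch, not part of the patch, of the updated initialization pattern under torchrun with the mscclpp NCCL shim preloaded; the trailing all_reduce check is illustrative only.

    # Minimal sketch (assumed standalone script, hypothetical filename sketch.py).
    # Run as:
    #   LD_PRELOAD=/build/lib/libmscclpp_nccl.so torchrun --nnodes=1 --nproc_per_node=8 sketch.py
    import os
    import torch
    import torch.distributed as dist

    rank = int(os.environ["RANK"])         # set by torchrun
    world = int(os.environ["WORLD_SIZE"])  # set by torchrun
    local = int(os.environ["LOCAL_RANK"])  # set by torchrun

    torch.cuda.set_device(local)
    # The fix from the diff: pass a torch.device, not a bare int, as device_id.
    dist.init_process_group(backend="nccl", device_id=torch.device("cuda", local))

    x = torch.ones(1 << 20, device=f"cuda:{local}")
    dist.all_reduce(x)  # dispatched to the preloaded libmscclpp_nccl.so
    assert torch.allclose(x, torch.full_like(x, float(world)))
    dist.destroy_process_group()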