Integrate MSCCL++ DSL to torch workload (#620)

Provides two ways to integrate the MSCCL++ DSL.
1. Integrate with customized communication group
2. Integrate with NCCL API

Introduce new Python APIs to make it work:
```python
mscclpp.compile # compile dsl to json based execution plan
mscclpp.ExecutionPlanRegistry.register_plan(plan) # register the compiled plan to the ExecutionPlanRegistry
mscclpp.ExecutionPlanRegistry.set_selector(selector) # set the selector; the selector will return the best execution plan based on collective, message size, world size, etc.
```
Fix #556

---------

Co-authored-by: Caio Rocha <caiorocha@microsoft.com>
Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
This commit is contained in:
Binyang Li
2025-10-29 15:39:00 -07:00
committed by GitHub
parent 9994f53cea
commit 5acac93dbc
48 changed files with 1438 additions and 277 deletions

View File

@@ -1,12 +1,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
add_subdirectory(mscclpp)
add_subdirectory(csrc)
add_subdirectory(test)
add_custom_target(pytest_lib_copy ALL
COMMAND ${CMAKE_COMMAND} -E copy_if_different
${CMAKE_CURRENT_BINARY_DIR}/mscclpp/_mscclpp.*.so
${CMAKE_CURRENT_BINARY_DIR}/csrc/_mscclpp.*.so
${CMAKE_CURRENT_SOURCE_DIR}/mscclpp
COMMAND ${CMAKE_COMMAND} -E copy_if_different
${CMAKE_CURRENT_BINARY_DIR}/test/_ext.*.so

View File

@@ -0,0 +1,81 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <nanobind/nanobind.h>
#include <nanobind/stl/function.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/unordered_map.h>
#include <nanobind/stl/vector.h>
#include <mscclpp/executor.hpp>
#include <mscclpp/gpu.hpp>
namespace nb = nanobind;
using namespace mscclpp;
// Registers the executor-related Python bindings on module `m`:
// the DataType/PacketType enums, ExecutionRequest, ExecutionPlanHandle
// (plus its Constraint), the ExecutionPlanRegistry singleton,
// ExecutionPlan, and the Executor itself.
void register_executor(nb::module_& m) {
  // Element datatypes the executor can operate on.
  nb::enum_<DataType>(m, "DataType")
      .value("int32", DataType::INT32)
      .value("uint32", DataType::UINT32)
      .value("float16", DataType::FLOAT16)
      .value("float32", DataType::FLOAT32)
      .value("bfloat16", DataType::BFLOAT16);
  // Low-latency packet formats.
  nb::enum_<PacketType>(m, "PacketType").value("LL8", PacketType::LL8).value("LL16", PacketType::LL16);
  // Read-only view of a single execution request. Raw buffer pointers are
  // exposed to Python as integer addresses (uintptr_t).
  nb::class_<ExecutionRequest>(m, "ExecutionRequest")
      .def_ro("world_size", &ExecutionRequest::worldSize)
      .def_ro("n_ranks_per_node", &ExecutionRequest::nRanksPerNode)
      .def_prop_ro(
          "input_buffer",
          [](const ExecutionRequest& self) -> uintptr_t { return reinterpret_cast<uintptr_t>(self.inputBuffer); })
      .def_prop_ro(
          "output_buffer",
          [](const ExecutionRequest& self) -> uintptr_t { return reinterpret_cast<uintptr_t>(self.outputBuffer); })
      .def_ro("message_size", &ExecutionRequest::messageSize)
      .def_prop_ro("collective", [](ExecutionRequest& self) -> const std::string& { return self.collective; })
      .def_prop_ro("hints", [](ExecutionRequest& self) { return self.hints; });
  // Handle bundling a plan with its id, placement constraint, and user tags.
  // NOTE: the attribute is exported as "constraint" (singular).
  nb::class_<ExecutionPlanHandle>(m, "ExecutionPlanHandle")
      .def_ro("id", &ExecutionPlanHandle::id)
      .def_ro("constraint", &ExecutionPlanHandle::constraint)
      .def_ro("plan", &ExecutionPlanHandle::plan)
      .def_ro("tags", &ExecutionPlanHandle::tags)
      .def_static("create", &ExecutionPlanHandle::create, nb::arg("id"), nb::arg("world_size"),
                  nb::arg("nranks_per_node"), nb::arg("plan"),
                  nb::arg("tags") = std::unordered_map<std::string, uint64_t>{})
  ;
  nb::class_<ExecutionPlanHandle::Constraint>(m, "ExecutionPlanConstraint")
      .def_ro("world_size", &ExecutionPlanHandle::Constraint::worldSize)
      .def_ro("n_ranks_per_node", &ExecutionPlanHandle::Constraint::nRanksPerNode);
  // Process-wide plan registry; obtained via get_instance() (singleton).
  nb::class_<ExecutionPlanRegistry>(m, "ExecutionPlanRegistry")
      .def_static("get_instance", &ExecutionPlanRegistry::getInstance)
      .def("register_plan", &ExecutionPlanRegistry::registerPlan, nb::arg("planHandle"))
      .def("get_plans", &ExecutionPlanRegistry::getPlans, nb::arg("collective"))
      .def("get", &ExecutionPlanRegistry::get, nb::arg("id"))
      .def("set_selector", &ExecutionPlanRegistry::setSelector, nb::arg("selector"))
      .def("set_default_selector", &ExecutionPlanRegistry::setDefaultSelector, nb::arg("selector"))
      .def("clear", &ExecutionPlanRegistry::clear);
  // Execution plan loaded from a JSON file; metadata exposed as read-only
  // properties (the previous binding exposed them as methods).
  nb::class_<ExecutionPlan>(m, "ExecutionPlan")
      .def(nb::init<const std::string&, int>(), nb::arg("planPath"), nb::arg("rank"))
      .def_prop_ro("name", [](const ExecutionPlan& self) -> std::string { return self.name(); })
      .def_prop_ro("collective", [](const ExecutionPlan& self) -> std::string { return self.collective(); })
      .def_prop_ro("min_message_size", [](const ExecutionPlan& self) -> size_t { return self.minMessageSize(); })
      .def_prop_ro("max_message_size", [](const ExecutionPlan& self) -> size_t { return self.maxMessageSize(); });
  // Executor: buffers and the stream arrive from Python as integer addresses
  // and are reinterpreted back to pointers for the C++ API.
  nb::class_<Executor>(m, "Executor")
      .def(nb::init<std::shared_ptr<Communicator>>(), nb::arg("comm"))
      .def(
          "execute",
          [](Executor* self, int rank, uintptr_t sendbuff, uintptr_t recvBuff, size_t sendBuffSize, size_t recvBuffSize,
             DataType dataType, const ExecutionPlan& plan, uintptr_t stream, PacketType packetType) {
            self->execute(rank, reinterpret_cast<void*>(sendbuff), reinterpret_cast<void*>(recvBuff), sendBuffSize,
                          recvBuffSize, dataType, plan, (cudaStream_t)stream, packetType);
          },
          nb::arg("rank"), nb::arg("send_buff"), nb::arg("recv_buff"), nb::arg("send_buff_size"),
          nb::arg("recv_buff_size"), nb::arg("data_type"), nb::arg("plan"), nb::arg("stream"),
          nb::arg("packet_type") = PacketType::LL16);
}

View File

@@ -3,8 +3,20 @@
"""MSCCL++ Python API."""
import atexit
from dataclasses import dataclass
from functools import cached_property, wraps
import inspect
import json
import os
from pathlib import Path
from typing import Any
import warnings
from blake3 import blake3
from mscclpp.language.program import CollectiveProgram
from mscclpp.language.utils import AlgoSpec
from functools import wraps
from mscclpp._version import __version__, __commit_id__
@@ -49,11 +61,14 @@ from ._mscclpp import (
DataType,
Executor,
ExecutionPlan,
ExecutionPlanConstraint,
PacketType,
RawGpuBuffer,
env,
is_nvls_supported,
npkit,
ExecutionPlanHandle as _ExecutionPlanHandle,
ExecutionPlanRegistry as _ExecutionPlanRegistry,
)
__all__ = [
@@ -79,6 +94,9 @@ __all__ = [
"Executor",
"ExecutionPlan",
"PacketType",
"RawGpuBuffer",
"env",
"version",
"is_nvls_supported",
"alloc_shared_physical_cuda",
"npkit",
@@ -87,10 +105,6 @@ __all__ = [
"version",
"get_include",
"get_lib",
### Deprecated ###
"ProxyChannel",
"SmChannel",
"SmDevice2DeviceSemaphore",
]
@@ -119,16 +133,193 @@ def deprecated(new_cls):
return decorator
@deprecated(PortChannel)
class ProxyChannel(PortChannel):
pass
class ExecutionPlanHandle:
    """Python-side wrapper around the native ``_ExecutionPlanHandle``.

    Exposes the handle's immutable attributes through cached properties so
    the native object is only queried once per attribute.
    """

    def __init__(self, handle: _ExecutionPlanHandle):
        self._handle = handle

    @cached_property
    def id(self) -> str:
        # The id is a content-addressed blake3 hexdigest (see compile()),
        # hence ``str`` rather than ``int``.
        return self._handle.id

    @cached_property
    def tags(self) -> frozenset:
        # frozenset over the native tag mapping yields its keys, immutably.
        return frozenset(self._handle.tags)

    @cached_property
    def plan(self) -> ExecutionPlan:
        return self._handle.plan

    @cached_property
    def constraints(self) -> ExecutionPlanConstraint:
        # Fix: the native binding exposes this attribute as ``constraint``
        # (singular); accessing ``constraints`` raised AttributeError.
        return self._handle.constraint
@deprecated(MemoryChannel)
class SmChannel(MemoryChannel):
pass
@dataclass(frozen=True)
class ExecutionRequest:
    """Immutable description of one requested collective execution.

    Instances are handed to plan selectors (see
    ``ExecutionPlanRegistry.select``) so they can pick the best plan
    for the request.
    """

    collective: str  # collective name, e.g. "allreduce"
    world_size: int  # total number of participating ranks
    n_ranks_per_node: int  # ranks co-located on each node
    send_buffer: int  # send-buffer device address as an integer
    recv_buffer: int  # receive-buffer device address as an integer
    message_size: int  # message size (presumably bytes -- TODO confirm)
    hints: dict  # free-form hints forwarded to the selector
@deprecated(MemoryDevice2DeviceSemaphore)
class SmDevice2DeviceSemaphore(MemoryDevice2DeviceSemaphore):
pass
class ExecutionPlanRegistry:
    """Process-wide singleton registry of compiled execution plans.

    Mirrors every registration into the native ``_ExecutionPlanRegistry``
    while keeping Python-side indexes for lookup by plan id and by
    collective name.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(ExecutionPlanRegistry, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every construction of the singleton; only
        # initialize state the first time.
        if not hasattr(self, "_initialized"):
            self._registry = _ExecutionPlanRegistry.get_instance()
            self._id_map = {}
            self._collective_map = {}
            self._selector = None
            self._initialized = True

    def register_plan(self, plan: ExecutionPlanHandle):
        """Index ``plan`` by id and collective, and register it natively."""
        self._id_map[plan.id] = plan
        self._collective_map.setdefault(plan.plan.collective, []).append(plan)
        # self._registry is the same object as self._instance._registry;
        # going through self is clearer.
        return self._registry.register_plan(plan._handle)

    def set_selector(self, selector):
        """Install the user selector on the Python and native registries."""
        self._selector = selector
        self._registry.set_selector(selector)

    def set_default_selector(self, selector):
        """Install the fallback selector used when none was set explicitly."""
        self._selector = selector
        self._registry.set_default_selector(selector)

    def get(self, id: str) -> ExecutionPlanHandle:
        """Return the handle registered under ``id``, or None."""
        return self._id_map.get(id)

    def select(
        self,
        collective: str,
        world_size: int,
        n_ranks_per_node: int,
        send_buffer: int,
        recv_buffer: int,
        message_size: int,
        hints: "dict | None" = None,
    ) -> ExecutionPlanHandle:
        """Ask the installed selector for the best plan, or None if no
        selector is set or no plan exists for ``collective``."""
        if self._selector is None or collective not in self._collective_map:
            return None
        req = ExecutionRequest(
            collective=collective,
            world_size=world_size,
            n_ranks_per_node=n_ranks_per_node,
            send_buffer=send_buffer,
            recv_buffer=recv_buffer,
            message_size=message_size,
            # None default instead of a mutable ``{}`` default, which would
            # be shared across every call to select().
            hints=hints if hints is not None else {},
        )
        return self._selector(self._collective_map[collective], req)

    @classmethod
    def reset_instance(cls):
        """Clear native and Python-side state and drop the singleton."""
        if cls._instance is not None:
            cls._instance._registry.clear()
            cls._instance._id_map = {}
            cls._instance._collective_map = {}
            cls._instance._selector = None
        cls._instance = None
atexit.register(ExecutionPlanRegistry.reset_instance)
_execution_plan_registry = ExecutionPlanRegistry()
def _stable_json_bytes(obj: Any) -> bytes:
return json.dumps(
obj,
sort_keys=True,
ensure_ascii=False,
separators=(",", ":"),
).encode("utf-8")
def compile(
    algo,
    algo_spec: AlgoSpec,
    rank: int,
    **kwargs,
) -> ExecutionPlanHandle:
    """Compile a MSCCL++ program from a high-level algorithm description.

    The plan id is content-addressed (blake3 over the package version, the
    algorithm's source code, and its configuration), so identical inputs
    always map to the same id and the same cached plan file on disk.

    Args:
        algo: The high-level algorithm description (e.g., a function or class).
        algo_spec (AlgoSpec): Algorithm specification containing collective type,
            world size, ranks per node, instances, protocol, and other configuration.
        rank (int): The rank of the current process.
        **kwargs: Additional keyword arguments passed to the algorithm function.

    Returns:
        ExecutionPlanHandle: The compiled execution plan handle.

    Raises:
        ValueError: If the 'algo' argument is not callable.
    """
    if not callable(algo):
        raise ValueError("The 'algo' argument must be a callable (e.g., a function or class).")
    prog: CollectiveProgram = algo(
        algo_spec,
        **kwargs,
    )
    source = inspect.getsource(algo)
    source_hash = blake3(source.encode("utf-8")).hexdigest()
    plan_id = blake3(
        _stable_json_bytes(
            {
                "version": __version__,
                "algo_name": algo_spec.name,
                "collective": algo_spec.collective.name,
                "tags": sorted(algo_spec.tags.items()),
                "source_hash": source_hash,
                "envs": {
                    "nranks_per_node": algo_spec.nranks_per_node,
                    "world_size": algo_spec.world_size,
                    "instances": algo_spec.instances,
                    "protocol": algo_spec.protocol,
                },
            }
        )
    ).hexdigest()
    # Fast path: the plan was already compiled and registered in this process.
    plan_handle = _execution_plan_registry.get(plan_id)
    if plan_handle is not None:
        return plan_handle
    plan_dir = os.environ.get("MSCCLPP_EXECUTION_PLAN_DIR", Path.home() / ".cache/mscclpp")
    os.makedirs(plan_dir, exist_ok=True)
    plan_path = os.path.join(plan_dir, f"{plan_id}.json")
    if not os.path.exists(plan_path):
        # TODO (binyli): Each rank could generate its own execution plan separately. Doesn't need to generate whole plan.
        # Write to a per-process temp file, then atomically publish it.
        tmp_path = plan_path + f".tmp.{os.getpid()}"
        try:
            with open(tmp_path, "w") as f:
                # post-processing runs inside to_json()
                f.write(prog.to_json(indent=None, separators=(",", ":"), ensure_ascii=False))
                f.flush()
                os.fsync(f.fileno())
            # os.replace is atomic and tolerates a concurrent writer having
            # already published the (identical, content-addressed) file --
            # unlike the previous exists-check + rename, which raced and, on
            # failure, deleted plan_path even if another rank had written it.
            os.replace(tmp_path, plan_path)
        finally:
            # Never leave the temp file behind; a write failure propagates to
            # the caller instead of surfacing later as a missing-file error.
            Path(tmp_path).unlink(missing_ok=True)
    execution_plan = ExecutionPlan(plan_path, rank)
    handle = _ExecutionPlanHandle.create(
        id=plan_id,
        world_size=algo_spec.world_size,
        nranks_per_node=algo_spec.nranks_per_node,
        plan=execution_plan,
        tags=algo_spec.tags,
    )
    return ExecutionPlanHandle(handle)

View File

@@ -0,0 +1,77 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import shutil
import argparse
from pathlib import Path
from mscclpp.language import default_algos as def_algo
from mscclpp.language.collectives import *
from mscclpp.language.utils import AlgoSpec
# Catalog of built-in algorithms shipped with mscclpp. Each entry names the
# output JSON file, the generator function, its AlgoSpec, and any extra
# positional arguments the generator takes beyond the spec.
default_algo_configs = [
    {
        "filename": "allreduce_2nodes.json",
        "function": def_algo.allreduce_2nodes,
        "spec": AlgoSpec(
            name="allreduce_2nodes",
            collective=AllReduce(16, 1, True),  # 16 ranks, chunk factor 1, in-place
            nranks_per_node=8,
            world_size=16,  # 2 nodes x 8 ranks
            in_place=True,
            instances=1,
            protocol="LL",
            auto_sync=False,
            num_threads_per_block=1024,
            reuse_resources=True,
            use_double_scratch_buffer=True,
            min_message_size=1 << 10,  # 1 KiB
            max_message_size=2 << 20,  # 2 MiB
            tags={"default": 1},
        ),
        # Extra positional args after the spec: here, thread_block_group_size.
        "additional_args": [4],
    }
]
def create_default_plans():
    """Generate the built-in execution plans into the default plan directory.

    The target directory (``MSCCLPP_EXECUTION_PLAN_DIR``, defaulting to
    ``~/.cache/mscclpp_default``) is wiped and recreated, then one JSON plan
    file is written per entry in ``default_algo_configs``. A failure for one
    plan is reported and skipped so it does not block the remaining plans.
    """
    plan_dir = os.environ.get("MSCCLPP_EXECUTION_PLAN_DIR", Path.home() / ".cache/mscclpp_default")
    root = Path(plan_dir)
    # Start from a clean directory so stale plans never linger.
    if root.exists():
        shutil.rmtree(root)
    root.mkdir(parents=True)
    for config in default_algo_configs:
        func = config["function"]
        spec = config["spec"]
        additional_args = config.get("additional_args", [])
        # Distinct name for the per-plan file (the original reused the
        # directory variable name here, shadowing it inside the loop).
        out_path = root / config["filename"]
        try:
            # Star-unpacking covers both the empty and non-empty case; no
            # need for the previous if/else branch.
            prog = func(spec, *additional_args)
            with open(out_path, "w", encoding="utf-8") as f:
                f.write(prog.to_json())
                f.flush()
        except Exception as e:
            # Report and continue with the next plan.
            print(f"Error creating plan for {spec.name}: {e}")
def main():
    """CLI entry point: ``--install`` regenerates the default plans."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--install", action="store_true", help="flag to install default plans")
    parsed = arg_parser.parse_args()
    if parsed.install:
        create_default_plans()
if __name__ == "__main__":
main()

View File

@@ -1,43 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <nanobind/nanobind.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/string.h>
#include <mscclpp/executor.hpp>
#include <mscclpp/gpu.hpp>
namespace nb = nanobind;
using namespace mscclpp;
// Legacy executor bindings (removed by this commit): DataType/PacketType
// enums, ExecutionPlan with metadata exposed as methods (not properties),
// and the Executor execute() wrapper taking integer buffer/stream addresses.
void register_executor(nb::module_& m) {
  nb::enum_<DataType>(m, "DataType")
      .value("int32", DataType::INT32)
      .value("uint32", DataType::UINT32)
      .value("float16", DataType::FLOAT16)
      .value("float32", DataType::FLOAT32)
      .value("bfloat16", DataType::BFLOAT16);
  nb::enum_<PacketType>(m, "PacketType").value("LL8", PacketType::LL8).value("LL16", PacketType::LL16);
  // Metadata accessors are bound as callable methods here; the replacement
  // binding converts them to read-only properties.
  nb::class_<ExecutionPlan>(m, "ExecutionPlan")
      .def(nb::init<const std::string&, int>(), nb::arg("planPath"), nb::arg("rank"))
      .def("name", &ExecutionPlan::name)
      .def("collective", &ExecutionPlan::collective)
      .def("min_message_size", &ExecutionPlan::minMessageSize)
      .def("max_message_size", &ExecutionPlan::maxMessageSize);
  // Buffers and stream arrive as integer addresses and are reinterpreted
  // back to pointers for the C++ API.
  nb::class_<Executor>(m, "Executor")
      .def(nb::init<std::shared_ptr<Communicator>>(), nb::arg("comm"))
      .def(
          "execute",
          [](Executor* self, int rank, uintptr_t sendbuff, uintptr_t recvBuff, size_t sendBuffSize, size_t recvBuffSize,
             DataType dataType, const ExecutionPlan& plan, uintptr_t stream, PacketType packetType) {
            self->execute(rank, reinterpret_cast<void*>(sendbuff), reinterpret_cast<void*>(recvBuff), sendBuffSize,
                          recvBuffSize, dataType, plan, (cudaStream_t)stream, packetType);
          },
          nb::arg("rank"), nb::arg("send_buff"), nb::arg("recv_buff"), nb::arg("send_buff_size"),
          nb::arg("recv_buff_size"), nb::arg("data_type"), nb::arg("plan"), nb::arg("stream"),
          nb::arg("packet_type") = PacketType::LL16);
}

View File

@@ -26,6 +26,11 @@ class MemoryChannel:
_channel_counts = defaultdict(int)
@classmethod
def reset(cls):
    """Reset all channel counts for this channel type.

    Clears the class-level counter shared by every instance, so channels
    created afterwards start numbering from zero again (invoked when a
    CollectiveProgram context exits).
    """
    cls._channel_counts.clear()
def __init__(self, dst_rank: int, src_rank: int):
"""Initialize a new MemoryChannel.
@@ -453,6 +458,11 @@ class PortChannel:
_channel_counts = defaultdict(int)
@classmethod
def reset(cls):
    """Reset all channel counts for this channel type.

    Clears the class-level counter shared by every instance, so channels
    created afterwards start numbering from zero again (invoked when a
    CollectiveProgram context exits).
    """
    cls._channel_counts.clear()
def __init__(self, dst_rank: int, src_rank: int):
"""Initialize a new PortChannel.
@@ -741,6 +751,11 @@ class SwitchChannel:
_channel_counts = defaultdict(int)
@classmethod
def reset(cls):
    """Reset all channel counts for this channel type.

    Clears the class-level counter shared by every instance, so channels
    created afterwards start numbering from zero again (invoked when a
    CollectiveProgram context exits).
    """
    cls._channel_counts.clear()
def __init__(self, rank_list: List[int], buffer_type: BufferType):
"""Initialize a new SwitchChannel.

View File

@@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from mscclpp.language.default_algos.allreduce_2nodes import allreduce_2nodes
__all__ = ["allreduce_2nodes"]

View File

@@ -7,7 +7,7 @@ This implements a hierarchical AllReduce: intra-node allreduce followed by
inter-node exchange and final intra-node allreduce.
"""
import argparse
from mscclpp.language.utils import AlgoSpec
from mscclpp.language.channel import *
from mscclpp.language.rank import *
from mscclpp.language.general import *
@@ -15,9 +15,7 @@ from mscclpp.language.program import *
from mscclpp.language.collectives import *
def allreduce_example(
program_name, gpus_per_node, thread_block_group_size, num_threads_per_block, min_message_size, max_message_size
):
def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size: int) -> CollectiveProgram:
"""
Implements a multi-node AllReduce using a hierarchical approach:
1. Intra-node allreduce
@@ -26,24 +24,11 @@ def allreduce_example(
"""
# Configuration constants
num_nodes = 2
gpus_per_node = spec.nranks_per_node
total_gpus = num_nodes * gpus_per_node
chunks_per_loop = 1
packets_per_gpu = 2 # Each GPU handles 2 data packets
packets_per_gpu = 2
# Initialize collective operation
collective = AllReduce(total_gpus, chunks_per_loop, True)
with CollectiveProgram(
program_name,
collective,
total_gpus,
protocol="LL",
num_threads_per_block=num_threads_per_block,
reuse_resources=False,
use_double_scratch_buffer=True,
min_message_size=min_message_size,
max_message_size=max_message_size,
):
with CollectiveProgram.from_spec(spec) as prog:
# Initialize communication channels and buffers
intra_node_memory_channels = {}
inter_node_port_channels = {}
@@ -175,25 +160,4 @@ def allreduce_example(
tb_group=thread_block_group,
)
print(JSON())
parser = argparse.ArgumentParser()
parser.add_argument("--name", type=str, help="name of the program")
parser.add_argument("--gpus_per_node", type=int, help="number of gpus per node")
parser.add_argument("--tbg_size", type=int, help="number of thread blocks in the thread block group")
parser.add_argument("--num_threads_per_block", type=int, default=1024, help="number of threads per block")
parser.add_argument("--min_message_size", type=int, default=0, help="minimum message size")
parser.add_argument("--max_message_size", type=int, default=2 * 2**20, help="maximum message size")
args = parser.parse_args()
allreduce_example(
args.name,
args.gpus_per_node,
args.tbg_size,
args.num_threads_per_block,
args.min_message_size,
args.max_message_size,
)
return prog

View File

@@ -16,5 +16,4 @@ def JSON():
str: A JSON string representation of the current MSCCL++ program,
including all ranks, operations, channels, and configuration.
"""
get_program().post_process_operations()
return get_program().to_json()

View File

@@ -6,6 +6,7 @@ from mscclpp.language.internal.optimizer import *
from mscclpp.language.internal.buffer_access import *
from dataclasses import dataclass, field
from collections import OrderedDict
from typing import List
@dataclass
@@ -88,7 +89,7 @@ class ThreadBlock:
@dataclass
class Channel:
channel_type: ChannelType
channel_ids: list[int] = field(default_factory=list)
channel_ids: List[int] = field(default_factory=list)
def to_dict(self) -> dict:
return {"channel_type": self.channel_type.value, "channel_ids": self.channel_ids}
@@ -96,7 +97,7 @@ class ThreadBlock:
@dataclass
class RemoteBuffer:
access_channel_type: ChannelType
remote_buffer_ids: list[int] = field(default_factory=list)
remote_buffer_ids: List[int] = field(default_factory=list)
def to_dict(self) -> dict:
return {

View File

@@ -4,7 +4,6 @@
from dataclasses import dataclass
from enum import Enum
from typing import List, Set
from collections import defaultdict
class SyncType(Enum):

View File

@@ -3,8 +3,12 @@
from mscclpp.language.collectives import Collective
from mscclpp.language.internal.globals import set_program
from mscclpp.language.internal.types import BufferType, RemoteBuffer, ChannelType, ReplicationPolicy
from mscclpp.language.internal.types import BufferType, RemoteBuffer, ChannelType
from mscclpp.language.internal.gpu import Gpu
from mscclpp.language.channel import *
from mscclpp.language.rank import Semaphore
from mscclpp.language.collectives import *
from mscclpp.language.utils import AlgoSpec, ReplicationPolicy
from typing import List
import json
@@ -108,6 +112,55 @@ class CollectiveProgram:
self.loop_context = None
@classmethod
def from_spec(cls, spec: AlgoSpec):
    """Initialize a new CollectiveProgram from an algorithm specification.

    This constructor provides an alternative way to create a CollectiveProgram
    using an AlgoSpec object, which contains the complete algorithm specification
    including collective instance, protocol parameters, and optimization settings.
    The collective operation is directly provided through the spec's collective attribute.

    Note: the spec's ``nranks_per_node``, ``in_place``, and ``tags`` fields are
    not forwarded here; only the constructor parameters listed below are used.

    Args:
        spec (AlgoSpec): Algorithm specification containing all program parameters
            and configuration settings, including a Collective instance.

    Raises:
        AssertionError: If protocol is not "Simple" or "LL".

    Example:
        >>> from mscclpp.language.utils import AlgoSpec
        >>> from mscclpp.language.collectives import AllReduce
        >>> collective = AllReduce(num_ranks=4, chunk_factor=1, inplace=False)
        >>> spec = AlgoSpec(
        ...     name="my_allreduce",
        ...     collective=collective,
        ...     nranks_per_node=4,
        ...     world_size=4,
        ...     in_place=False,
        ...     instances=1,
        ...     protocol="Simple",
        ... )
        >>> with CollectiveProgram.from_spec(spec) as prog:
        ...     # Define communication operations
        ...     pass
    """
    return cls(
        spec.name,
        spec.collective,
        spec.world_size,
        instances=spec.instances,
        protocol=spec.protocol,
        instr_fusion=spec.instr_fusion,
        auto_sync=spec.auto_sync,
        replication_policy=spec.replication_policy,
        reuse_resources=spec.reuse_resources,
        num_threads_per_block=spec.num_threads_per_block,
        use_double_scratch_buffer=spec.use_double_scratch_buffer,
        buffer_alignment=spec.buffer_alignment,
        min_message_size=spec.min_message_size,
        max_message_size=spec.max_message_size,
    )
def __enter__(self):
"""Enter the program context and set this as the active program.
@@ -115,6 +168,7 @@ class CollectiveProgram:
this program as the active program in the global context.
"""
set_program(self)
return self
def __exit__(self, exc_type, exc_value, traceback):
"""Exit the program context and clear the active program.
@@ -122,6 +176,10 @@ class CollectiveProgram:
This method is called when exiting the 'with' statement and removes
this program from the global context.
"""
MemoryChannel.reset()
PortChannel.reset()
SwitchChannel.reset()
Semaphore.reset()
set_program(None)
def add_channel(self, channel):
@@ -175,7 +233,8 @@ class CollectiveProgram:
raise RuntimeError("Nested Pipelines are not Supported.")
self.loop_context = loop_context
def to_json(self):
def to_json(self, indent=2, **kwargs):
self.post_process_operations()
json_obj = {
"name": self.name,
"collective": self.collective.name,
@@ -190,4 +249,4 @@ class CollectiveProgram:
"max_message_size": self.max_message_size,
}
return json.dumps(json_obj, indent=2)
return json.dumps(json_obj, indent=indent, **kwargs)

View File

@@ -367,6 +367,11 @@ class Semaphore:
_semaphore_counts = defaultdict(int)
@classmethod
def reset(cls):
    """Reset all semaphore counts.

    Clears the class-level counter shared by every instance, so semaphores
    created afterwards start numbering from zero again (invoked when a
    CollectiveProgram context exits).
    """
    cls._semaphore_counts.clear()
def __init__(self, rank: int, initial_value: int):
"""Initialize a new Semaphore.

View File

@@ -40,7 +40,7 @@ def allreduce_example(name, gpu_size, num_threads_per_block, min_message_size, m
input_buffer = rank.get_input_buffer()
for peer in range(gpu_size):
if peer != gpu:
channels[(peer, gpu)].put_packet(
channels[(peer, gpu)].put_packets(
scratch_buffer[peer][gpu : gpu + 1], input_buffer[peer : peer + 1], 0
)
@@ -55,7 +55,7 @@ def allreduce_example(name, gpu_size, num_threads_per_block, min_message_size, m
rank.reduce(input_buffer[gpu : gpu + 1], chunks, 0, packet=True)
for peer in range(gpu_size):
if peer != gpu:
channels[(peer, gpu)].put_packet(
channels[(peer, gpu)].put_packets(
scratch_buffer[peer][gpu_size + gpu : gpu_size + gpu + 1], input_buffer[gpu : gpu + 1], 0
)
@@ -65,7 +65,7 @@ def allreduce_example(name, gpu_size, num_threads_per_block, min_message_size, m
input_buffer = rank.get_input_buffer()
for peer in range(gpu_size):
if peer != gpu:
rank.unpack_packet(
rank.unpack_packets(
input_buffer[peer : peer + 1], scratch_buffer[gpu][gpu_size + peer : gpu_size + peer + 1], 0
)

View File

@@ -0,0 +1,35 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from enum import Enum
from dataclasses import dataclass, field
from mscclpp.language.collectives import Collective
class ReplicationPolicy(Enum):
    """Policy describing how a program is replicated across instances."""

    interleaved = "interleaved"
    none = "none"

    def __str__(self) -> str:
        # Render as the raw policy name (the enum's string value).
        return str(self.value)
@dataclass(frozen=True)
class AlgoSpec:
    """Immutable specification of a collective algorithm.

    Bundles the collective instance with protocol, sizing, and optimization
    settings. Consumed by ``CollectiveProgram.from_spec`` and by
    ``mscclpp.compile`` (which also hashes several of these fields into the
    content-addressed plan id).
    """

    name: str  # human-readable algorithm name
    collective: Collective  # collective instance (e.g. AllReduce)
    nranks_per_node: int  # ranks co-located on each node
    world_size: int  # total number of participating ranks
    in_place: bool  # whether input and output share a buffer
    instances: int  # number of program instances to replicate
    protocol: str  # "Simple" or "LL"
    instr_fusion: bool = True  # enable instruction fusion
    auto_sync: bool = True  # insert synchronization automatically
    replication_policy: ReplicationPolicy = ReplicationPolicy.interleaved  # how instances are laid out
    reuse_resources: bool = False  # reuse channels/buffers across instances -- TODO confirm scope
    num_threads_per_block: int = 1024
    use_double_scratch_buffer: bool = False
    buffer_alignment: int = 16  # alignment in bytes -- TODO confirm units
    min_message_size: int = 0  # smallest message size the plan supports
    max_message_size: int = 2**64 - 1  # largest message size the plan supports
    tags: dict = field(default_factory=dict)  # free-form tags (hashed into the plan id)

View File

@@ -6,3 +6,4 @@ pytest
numpy
matplotlib
sortedcontainers @ git+https://github.com/grantjenks/python-sortedcontainers.git@3ac358631f58c1347f1d6d2d92784117db0f38ed
blake3

View File

@@ -5,4 +5,5 @@ netifaces
pytest
numpy
matplotlib
sortedcontainers @ git+https://github.com/grantjenks/python-sortedcontainers.git@3ac358631f58c1347f1d6d2d92784117db0f38ed
sortedcontainers @ git+https://github.com/grantjenks/python-sortedcontainers.git@3ac358631f58c1347f1d6d2d92784117db0f38ed
blake3

View File

@@ -187,7 +187,7 @@ def main(
if npkit_dump_dir != "":
npkit.init(mscclpp_group.my_rank)
execution_plan = ExecutionPlan(execution_plan_path, mscclpp_group.my_rank)
collective = execution_plan.collective()
collective = execution_plan.collective
dtype = parse_dtype(dtype_str)
input_buf, result_buf, test_buf = build_bufs(