mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-24 14:54:51 +00:00
Integrate MSCCL++ DSL to torch workload (#620)
Provides two integration ways for MSCCL++ DSL. 1. Integrate with customized communication group 2. Integrate with NCCL API Introduce new Python APIs to make it work: ```python mscclpp.compile # compile dsl to json based execution plan mscclpp.ExecutionPlanRegistry.register_plan(plan) # register the compiled plan to executionPlanRegistery mscclpp.ExecutionPlanRegistry.set_selector(selector) # set the selector, the selector will return the best execution plan based on collection, message size, world size.... ``` Fix #556 --------- Co-authored-by: Caio Rocha <caiorocha@microsoft.com> Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
This commit is contained in:
@@ -1,26 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
|
||||
include(FetchContent)
|
||||
FetchContent_Declare(nanobind GIT_REPOSITORY https://github.com/wjakob/nanobind.git GIT_TAG v1.4.0)
|
||||
FetchContent_MakeAvailable(nanobind)
|
||||
|
||||
FetchContent_Declare(dlpack
|
||||
GIT_REPOSITORY https://github.com/dmlc/dlpack.git
|
||||
GIT_TAG 5c210da409e7f1e51ddf445134a4376fdbd70d7d
|
||||
)
|
||||
|
||||
FetchContent_GetProperties(dlpack)
|
||||
if(NOT dlpack_POPULATED)
|
||||
FetchContent_Populate(dlpack)
|
||||
# Add dlpack subdirectory but exclude it from installation
|
||||
add_subdirectory(${dlpack_SOURCE_DIR} ${dlpack_BINARY_DIR} EXCLUDE_FROM_ALL)
|
||||
endif()
|
||||
|
||||
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cpp)
|
||||
nanobind_add_module(mscclpp_py ${SOURCES})
|
||||
set_target_properties(mscclpp_py PROPERTIES OUTPUT_NAME _mscclpp)
|
||||
target_link_libraries(mscclpp_py PRIVATE dlpack mscclpp_static ${GPU_LIBRARIES})
|
||||
target_include_directories(mscclpp_py SYSTEM PRIVATE ${GPU_INCLUDE_DIRS})
|
||||
install(TARGETS mscclpp_py LIBRARY DESTINATION .)
|
||||
@@ -3,8 +3,20 @@
|
||||
|
||||
"""MSCCL++ Python API."""
|
||||
|
||||
import atexit
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property, wraps
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
import warnings
|
||||
|
||||
from blake3 import blake3
|
||||
|
||||
from mscclpp.language.program import CollectiveProgram
|
||||
from mscclpp.language.utils import AlgoSpec
|
||||
from functools import wraps
|
||||
from mscclpp._version import __version__, __commit_id__
|
||||
|
||||
@@ -49,11 +61,14 @@ from ._mscclpp import (
|
||||
DataType,
|
||||
Executor,
|
||||
ExecutionPlan,
|
||||
ExecutionPlanConstraint,
|
||||
PacketType,
|
||||
RawGpuBuffer,
|
||||
env,
|
||||
is_nvls_supported,
|
||||
npkit,
|
||||
ExecutionPlanHandle as _ExecutionPlanHandle,
|
||||
ExecutionPlanRegistry as _ExecutionPlanRegistry,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -79,6 +94,9 @@ __all__ = [
|
||||
"Executor",
|
||||
"ExecutionPlan",
|
||||
"PacketType",
|
||||
"RawGpuBuffer",
|
||||
"env",
|
||||
"version",
|
||||
"is_nvls_supported",
|
||||
"alloc_shared_physical_cuda",
|
||||
"npkit",
|
||||
@@ -87,10 +105,6 @@ __all__ = [
|
||||
"version",
|
||||
"get_include",
|
||||
"get_lib",
|
||||
### Deprecated ###
|
||||
"ProxyChannel",
|
||||
"SmChannel",
|
||||
"SmDevice2DeviceSemaphore",
|
||||
]
|
||||
|
||||
|
||||
@@ -119,16 +133,193 @@ def deprecated(new_cls):
|
||||
return decorator
|
||||
|
||||
|
||||
@deprecated(PortChannel)
|
||||
class ProxyChannel(PortChannel):
|
||||
pass
|
||||
class ExecutionPlanHandle:
|
||||
|
||||
def __init__(self, handle: _ExecutionPlanHandle):
|
||||
self._handle = handle
|
||||
|
||||
@cached_property
|
||||
def id(self) -> int:
|
||||
return self._handle.id
|
||||
|
||||
@cached_property
|
||||
def tags(self) -> set:
|
||||
return frozenset(self._handle.tags)
|
||||
|
||||
@cached_property
|
||||
def plan(self) -> ExecutionPlan:
|
||||
return self._handle.plan
|
||||
|
||||
@cached_property
|
||||
def constraints(self) -> ExecutionPlanConstraint:
|
||||
return self._handle.constraints
|
||||
|
||||
|
||||
@deprecated(MemoryChannel)
|
||||
class SmChannel(MemoryChannel):
|
||||
pass
|
||||
@dataclass(frozen=True)
|
||||
class ExecutionRequest:
|
||||
collective: str
|
||||
world_size: int
|
||||
n_ranks_per_node: int
|
||||
send_buffer: int
|
||||
recv_buffer: int
|
||||
message_size: int
|
||||
hints: dict
|
||||
|
||||
|
||||
@deprecated(MemoryDevice2DeviceSemaphore)
|
||||
class SmDevice2DeviceSemaphore(MemoryDevice2DeviceSemaphore):
|
||||
pass
|
||||
class ExecutionPlanRegistry:
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(ExecutionPlanRegistry, cls).__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(self, "_initialized"):
|
||||
self._registry = _ExecutionPlanRegistry.get_instance()
|
||||
self._id_map = {}
|
||||
self._collective_map = {}
|
||||
self._selector = None
|
||||
self._initialized = True
|
||||
|
||||
def register_plan(self, plan: ExecutionPlanHandle):
|
||||
self._id_map[plan.id] = plan
|
||||
if plan.plan.collective not in self._collective_map:
|
||||
self._collective_map[plan.plan.collective] = []
|
||||
self._collective_map[plan.plan.collective].append(plan)
|
||||
return self._instance._registry.register_plan(plan._handle)
|
||||
|
||||
def set_selector(self, selector):
|
||||
self._selector = selector
|
||||
self._instance._registry.set_selector(selector)
|
||||
|
||||
def set_default_selector(self, selector):
|
||||
self._selector = selector
|
||||
self._instance._registry.set_default_selector(selector)
|
||||
|
||||
def get(self, id: str) -> ExecutionPlanHandle:
|
||||
return self._id_map.get(id, None)
|
||||
|
||||
def select(
|
||||
self,
|
||||
collective: str,
|
||||
world_size: int,
|
||||
n_ranks_per_node: int,
|
||||
send_buffer: int,
|
||||
recv_buffer: int,
|
||||
message_size: int,
|
||||
hints: dict = {},
|
||||
) -> ExecutionPlanHandle:
|
||||
if self._selector is None or collective not in self._collective_map:
|
||||
return None
|
||||
req = ExecutionRequest(
|
||||
collective=collective,
|
||||
world_size=world_size,
|
||||
n_ranks_per_node=n_ranks_per_node,
|
||||
send_buffer=send_buffer,
|
||||
recv_buffer=recv_buffer,
|
||||
message_size=message_size,
|
||||
hints=hints,
|
||||
)
|
||||
return self._selector(self._collective_map[collective], req)
|
||||
|
||||
@classmethod
|
||||
def reset_instance(cls):
|
||||
if cls._instance is not None:
|
||||
cls._instance._registry.clear()
|
||||
cls._instance._id_map = {}
|
||||
cls._instance._collective_map = {}
|
||||
cls._instance._selector = None
|
||||
cls._instance = None
|
||||
|
||||
|
||||
atexit.register(ExecutionPlanRegistry.reset_instance)
|
||||
|
||||
_execution_plan_registry = ExecutionPlanRegistry()
|
||||
|
||||
|
||||
def _stable_json_bytes(obj: Any) -> bytes:
|
||||
return json.dumps(
|
||||
obj,
|
||||
sort_keys=True,
|
||||
ensure_ascii=False,
|
||||
separators=(",", ":"),
|
||||
).encode("utf-8")
|
||||
|
||||
|
||||
def compile(
|
||||
algo,
|
||||
algo_spec: AlgoSpec,
|
||||
rank: int,
|
||||
**kwargs,
|
||||
) -> ExecutionPlanHandle:
|
||||
"""Compile a MSCCL++ program from a high-level algorithm description.
|
||||
Args:
|
||||
algo: The high-level algorithm description (e.g., a function or class).
|
||||
algo_spec (AlgoSpec): Algorithm specification containing collective type,
|
||||
world size, ranks per node, instances, protocol, and other configuration.
|
||||
rank (int): The rank of the current process.
|
||||
**kwargs: Additional keyword arguments passed to the algorithm function.
|
||||
Returns:
|
||||
ExecutionPlanHandle: The compiled execution plan handle.
|
||||
Raises:
|
||||
ValueError: If the 'algo' argument is not callable.
|
||||
"""
|
||||
if not callable(algo):
|
||||
raise ValueError("The 'algo' argument must be a callable (e.g., a function or class).")
|
||||
prog: CollectiveProgram = algo(
|
||||
algo_spec,
|
||||
**kwargs,
|
||||
)
|
||||
source = inspect.getsource(algo)
|
||||
|
||||
source_hash = blake3(source.encode("utf-8")).hexdigest()
|
||||
plan_id = blake3(
|
||||
_stable_json_bytes(
|
||||
{
|
||||
"version": __version__,
|
||||
"algo_name": algo_spec.name,
|
||||
"collective": algo_spec.collective.name,
|
||||
"tags": sorted(algo_spec.tags.items()),
|
||||
"source_hash": source_hash,
|
||||
"envs": {
|
||||
"nranks_per_node": algo_spec.nranks_per_node,
|
||||
"world_size": algo_spec.world_size,
|
||||
"instances": algo_spec.instances,
|
||||
"protocol": algo_spec.protocol,
|
||||
},
|
||||
}
|
||||
)
|
||||
).hexdigest()
|
||||
plan_handle = _execution_plan_registry.get(plan_id)
|
||||
if plan_handle is not None:
|
||||
return plan_handle
|
||||
|
||||
plan_dir = os.environ.get("MSCCLPP_EXECUTION_PLAN_DIR", Path.home() / ".cache/mscclpp")
|
||||
os.makedirs(plan_dir, exist_ok=True)
|
||||
filename = f"{plan_id}.json"
|
||||
plan_path = os.path.join(plan_dir, filename)
|
||||
tmp_path = plan_path + f".tmp.{os.getpid()}"
|
||||
if not os.path.exists(plan_path):
|
||||
try:
|
||||
# TODO (binyli): Each rank could generate its own execution plan separately. Doesn't need to generate whole plan.
|
||||
with open(tmp_path, "w") as f:
|
||||
prog.post_process_operations()
|
||||
f.write(prog.to_json(indent=None, separators=(",", ":"), ensure_ascii=False))
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
if not os.path.exists(plan_path):
|
||||
os.rename(tmp_path, plan_path)
|
||||
else:
|
||||
os.remove(tmp_path)
|
||||
except Exception:
|
||||
Path(plan_path).unlink(missing_ok=True)
|
||||
execution_plan = ExecutionPlan(plan_path, rank)
|
||||
handle = _ExecutionPlanHandle.create(
|
||||
id=plan_id,
|
||||
world_size=algo_spec.world_size,
|
||||
nranks_per_node=algo_spec.nranks_per_node,
|
||||
plan=execution_plan,
|
||||
tags=algo_spec.tags,
|
||||
)
|
||||
return ExecutionPlanHandle(handle)
|
||||
|
||||
77
python/mscclpp/__main__.py
Normal file
77
python/mscclpp/__main__.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from mscclpp.language import default_algos as def_algo
|
||||
from mscclpp.language.collectives import *
|
||||
from mscclpp.language.utils import AlgoSpec
|
||||
|
||||
default_algo_configs = [
|
||||
{
|
||||
"filename": "allreduce_2nodes.json",
|
||||
"function": def_algo.allreduce_2nodes,
|
||||
"spec": AlgoSpec(
|
||||
name="allreduce_2nodes",
|
||||
collective=AllReduce(16, 1, True),
|
||||
nranks_per_node=8,
|
||||
world_size=16,
|
||||
in_place=True,
|
||||
instances=1,
|
||||
protocol="LL",
|
||||
auto_sync=False,
|
||||
num_threads_per_block=1024,
|
||||
reuse_resources=True,
|
||||
use_double_scratch_buffer=True,
|
||||
min_message_size=1 << 10,
|
||||
max_message_size=2 << 20,
|
||||
tags={"default": 1},
|
||||
),
|
||||
"additional_args": [4],
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def create_default_plans():
|
||||
plan_dir = os.environ.get("MSCCLPP_EXECUTION_PLAN_DIR", Path.home() / ".cache/mscclpp_default")
|
||||
plan_path = Path(plan_dir)
|
||||
if plan_path.exists():
|
||||
shutil.rmtree(plan_path)
|
||||
plan_path.mkdir(parents=True)
|
||||
|
||||
for config in default_algo_configs:
|
||||
filename = config["filename"]
|
||||
func = config["function"]
|
||||
spec = config["spec"]
|
||||
additional_args = config.get("additional_args", [])
|
||||
plan_path = os.path.join(plan_dir, filename)
|
||||
|
||||
try:
|
||||
if additional_args:
|
||||
prog = func(spec, *additional_args)
|
||||
else:
|
||||
prog = func(spec)
|
||||
|
||||
with open(plan_path, "w", encoding="utf-8") as f:
|
||||
f.write(prog.to_json())
|
||||
f.flush()
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error creating plan for {spec.name}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--install", action="store_true", help="flag to install default plans")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.install:
|
||||
create_default_plans()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,266 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/operators.h>
|
||||
#include <nanobind/stl/array.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <mscclpp/core.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
extern void register_env(nb::module_& m);
|
||||
extern void register_error(nb::module_& m);
|
||||
extern void register_port_channel(nb::module_& m);
|
||||
extern void register_memory_channel(nb::module_& m);
|
||||
extern void register_fifo(nb::module_& m);
|
||||
extern void register_semaphore(nb::module_& m);
|
||||
extern void register_utils(nb::module_& m);
|
||||
extern void register_numa(nb::module_& m);
|
||||
extern void register_nvls(nb::module_& m);
|
||||
extern void register_executor(nb::module_& m);
|
||||
extern void register_npkit(nb::module_& m);
|
||||
extern void register_gpu_utils(nb::module_& m);
|
||||
|
||||
template <typename T>
|
||||
void def_shared_future(nb::handle& m, const std::string& typestr) {
|
||||
std::string pyclass_name = std::string("shared_future_") + typestr;
|
||||
nb::class_<std::shared_future<T>>(m, pyclass_name.c_str()).def("get", &std::shared_future<T>::get);
|
||||
}
|
||||
|
||||
void register_core(nb::module_& m) {
|
||||
m.def("version", &version);
|
||||
|
||||
nb::class_<Bootstrap>(m, "Bootstrap")
|
||||
.def("get_rank", &Bootstrap::getRank)
|
||||
.def("get_n_ranks", &Bootstrap::getNranks)
|
||||
.def("get_n_ranks_per_node", &Bootstrap::getNranksPerNode)
|
||||
.def(
|
||||
"send",
|
||||
[](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) {
|
||||
void* data = reinterpret_cast<void*>(ptr);
|
||||
self->send(data, size, peer, tag);
|
||||
},
|
||||
nb::arg("data"), nb::arg("size"), nb::arg("peer"), nb::arg("tag"))
|
||||
.def(
|
||||
"recv",
|
||||
[](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) {
|
||||
void* data = reinterpret_cast<void*>(ptr);
|
||||
self->recv(data, size, peer, tag);
|
||||
},
|
||||
nb::arg("data"), nb::arg("size"), nb::arg("peer"), nb::arg("tag"))
|
||||
.def("all_gather", &Bootstrap::allGather, nb::arg("allData"), nb::arg("size"))
|
||||
.def("barrier", &Bootstrap::barrier)
|
||||
.def("send", static_cast<void (Bootstrap::*)(const std::vector<char>&, int, int)>(&Bootstrap::send),
|
||||
nb::arg("data"), nb::arg("peer"), nb::arg("tag"))
|
||||
.def("recv", static_cast<void (Bootstrap::*)(std::vector<char>&, int, int)>(&Bootstrap::recv), nb::arg("data"),
|
||||
nb::arg("peer"), nb::arg("tag"));
|
||||
|
||||
nb::class_<UniqueId>(m, "UniqueId");
|
||||
|
||||
nb::class_<TcpBootstrap, Bootstrap>(m, "TcpBootstrap")
|
||||
.def(nb::init<int, int>(), "Do not use this constructor. Use create instead.")
|
||||
.def_static(
|
||||
"create", [](int rank, int nRanks) { return std::make_shared<TcpBootstrap>(rank, nRanks); }, nb::arg("rank"),
|
||||
nb::arg("nRanks"))
|
||||
.def_static("create_unique_id", &TcpBootstrap::createUniqueId)
|
||||
.def("get_unique_id", &TcpBootstrap::getUniqueId)
|
||||
.def("initialize", static_cast<void (TcpBootstrap::*)(UniqueId, int64_t)>(&TcpBootstrap::initialize),
|
||||
nb::call_guard<nb::gil_scoped_release>(), nb::arg("unique_id"), nb::arg("timeout_sec") = 30)
|
||||
.def("initialize", static_cast<void (TcpBootstrap::*)(const std::string&, int64_t)>(&TcpBootstrap::initialize),
|
||||
nb::call_guard<nb::gil_scoped_release>(), nb::arg("if_ip_port_trio"), nb::arg("timeout_sec") = 30);
|
||||
|
||||
nb::enum_<Transport>(m, "Transport")
|
||||
.value("Unknown", Transport::Unknown)
|
||||
.value("CudaIpc", Transport::CudaIpc)
|
||||
.value("IB0", Transport::IB0)
|
||||
.value("IB1", Transport::IB1)
|
||||
.value("IB2", Transport::IB2)
|
||||
.value("IB3", Transport::IB3)
|
||||
.value("IB4", Transport::IB4)
|
||||
.value("IB5", Transport::IB5)
|
||||
.value("IB6", Transport::IB6)
|
||||
.value("IB7", Transport::IB7)
|
||||
.value("NumTransports", Transport::NumTransports);
|
||||
|
||||
nb::class_<TransportFlags>(m, "TransportFlags")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
|
||||
.def("has", &TransportFlags::has, nb::arg("transport"))
|
||||
.def("none", &TransportFlags::none)
|
||||
.def("any", &TransportFlags::any)
|
||||
.def("all", &TransportFlags::all)
|
||||
.def("count", &TransportFlags::count)
|
||||
.def(nb::self | nb::self)
|
||||
.def(nb::self | Transport())
|
||||
.def(nb::self & nb::self)
|
||||
.def(nb::self & Transport())
|
||||
.def(nb::self ^ nb::self)
|
||||
.def(nb::self ^ Transport())
|
||||
.def(
|
||||
"__ior__", [](TransportFlags& lhs, const TransportFlags& rhs) { return lhs |= rhs; }, nb::is_operator())
|
||||
.def(
|
||||
"__iand__", [](TransportFlags& lhs, const TransportFlags& rhs) { return lhs &= rhs; }, nb::is_operator())
|
||||
.def(
|
||||
"__ixor__", [](TransportFlags& lhs, const TransportFlags& rhs) { return lhs ^= rhs; }, nb::is_operator())
|
||||
.def(~nb::self)
|
||||
.def(nb::self == nb::self)
|
||||
.def(nb::self != nb::self);
|
||||
|
||||
nb::enum_<DeviceType>(m, "DeviceType")
|
||||
.value("Unknown", DeviceType::Unknown)
|
||||
.value("CPU", DeviceType::CPU)
|
||||
.value("GPU", DeviceType::GPU);
|
||||
|
||||
nb::class_<Device>(m, "Device")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init_implicit<DeviceType>(), nb::arg("type"))
|
||||
.def(nb::init<DeviceType, int>(), nb::arg("type"), nb::arg("id") = -1)
|
||||
.def_rw("type", &Device::type)
|
||||
.def_rw("id", &Device::id)
|
||||
.def("__str__", [](const Device& self) { return std::to_string(self); });
|
||||
|
||||
nb::class_<EndpointConfig::Ib>(m, "EndpointConfigIb")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init<int, int, int, int>(), nb::arg("max_cq_size") = EndpointConfig::Ib::DefaultMaxCqSize,
|
||||
nb::arg("max_cq_poll_num") = EndpointConfig::Ib::DefaultMaxCqPollNum,
|
||||
nb::arg("max_send_wr") = EndpointConfig::Ib::DefaultMaxSendWr,
|
||||
nb::arg("max_wr_per_send") = EndpointConfig::Ib::DefaultMaxWrPerSend)
|
||||
.def_rw("max_cq_size", &EndpointConfig::Ib::maxCqSize)
|
||||
.def_rw("max_cq_poll_num", &EndpointConfig::Ib::maxCqPollNum)
|
||||
.def_rw("max_send_wr", &EndpointConfig::Ib::maxSendWr)
|
||||
.def_rw("max_wr_per_send", &EndpointConfig::Ib::maxWrPerSend);
|
||||
|
||||
nb::class_<RegisteredMemory>(m, "RegisteredMemory")
|
||||
.def(nb::init<>())
|
||||
.def("data", [](RegisteredMemory& self) { return reinterpret_cast<uintptr_t>(self.data()); })
|
||||
.def("size", &RegisteredMemory::size)
|
||||
.def("transports", &RegisteredMemory::transports)
|
||||
.def("serialize", &RegisteredMemory::serialize)
|
||||
.def_static("deserialize", &RegisteredMemory::deserialize, nb::arg("data"));
|
||||
|
||||
nb::class_<Endpoint>(m, "Endpoint")
|
||||
.def("config", &Endpoint::config)
|
||||
.def("transport", &Endpoint::transport)
|
||||
.def("device", &Endpoint::device)
|
||||
.def("max_write_queue_size", &Endpoint::maxWriteQueueSize)
|
||||
.def("serialize", &Endpoint::serialize)
|
||||
.def_static("deserialize", &Endpoint::deserialize, nb::arg("data"));
|
||||
|
||||
nb::class_<Connection>(m, "Connection")
|
||||
.def("write", &Connection::write, nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("srcOffset"),
|
||||
nb::arg("size"))
|
||||
.def(
|
||||
"update_and_sync",
|
||||
[](Connection* self, RegisteredMemory dst, uint64_t dstOffset, uintptr_t src, uint64_t newValue) {
|
||||
self->updateAndSync(dst, dstOffset, (uint64_t*)src, newValue);
|
||||
},
|
||||
nb::arg("dst"), nb::arg("dst_offset"), nb::arg("src"), nb::arg("new_value"))
|
||||
.def("flush", &Connection::flush, nb::call_guard<nb::gil_scoped_release>(),
|
||||
nb::arg("timeout_usec") = (int64_t)3e7)
|
||||
.def("transport", &Connection::transport)
|
||||
.def("remote_transport", &Connection::remoteTransport)
|
||||
.def("context", &Connection::context)
|
||||
.def("local_device", &Connection::localDevice)
|
||||
.def("get_max_write_queue_size", &Connection::getMaxWriteQueueSize);
|
||||
|
||||
nb::class_<EndpointConfig>(m, "EndpointConfig")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
|
||||
.def(nb::init<Transport, Device, int, EndpointConfig::Ib>(), nb::arg("transport"), nb::arg("device"),
|
||||
nb::arg("max_write_queue_size") = -1, nb::arg("ib") = EndpointConfig::Ib{})
|
||||
.def_rw("transport", &EndpointConfig::transport)
|
||||
.def_rw("device", &EndpointConfig::device)
|
||||
.def_rw("ib", &EndpointConfig::ib)
|
||||
.def_prop_rw(
|
||||
"ib_max_cq_size", [](EndpointConfig& self) { return self.ib.maxCqSize; },
|
||||
[](EndpointConfig& self, int v) { self.ib.maxCqSize = v; })
|
||||
.def_prop_rw(
|
||||
"ib_max_cq_poll_num", [](EndpointConfig& self) { return self.ib.maxCqPollNum; },
|
||||
[](EndpointConfig& self, int v) { self.ib.maxCqPollNum = v; })
|
||||
.def_prop_rw(
|
||||
"ib_max_send_wr", [](EndpointConfig& self) { return self.ib.maxSendWr; },
|
||||
[](EndpointConfig& self, int v) { self.ib.maxSendWr = v; })
|
||||
.def_prop_rw(
|
||||
"ib_max_wr_per_send", [](EndpointConfig& self) { return self.ib.maxWrPerSend; },
|
||||
[](EndpointConfig& self, int v) { self.ib.maxWrPerSend = v; })
|
||||
.def_rw("max_write_queue_size", &EndpointConfig::maxWriteQueueSize);
|
||||
|
||||
nb::class_<Context>(m, "Context")
|
||||
.def_static("create", &Context::create)
|
||||
.def(
|
||||
"register_memory",
|
||||
[](Context* self, uintptr_t ptr, size_t size, TransportFlags transports) {
|
||||
return self->registerMemory((void*)ptr, size, transports);
|
||||
},
|
||||
nb::arg("ptr"), nb::arg("size"), nb::arg("transports"))
|
||||
.def("create_endpoint", &Context::createEndpoint, nb::arg("config"))
|
||||
.def("connect", &Context::connect, nb::arg("local_endpoint"), nb::arg("remote_endpoint"));
|
||||
|
||||
nb::class_<SemaphoreStub>(m, "SemaphoreStub")
|
||||
.def(nb::init<std::shared_ptr<Connection>>(), nb::arg("connection"))
|
||||
.def("memory", &SemaphoreStub::memory)
|
||||
.def("serialize", &SemaphoreStub::serialize)
|
||||
.def_static("deserialize", &SemaphoreStub::deserialize, nb::arg("data"));
|
||||
|
||||
nb::class_<Semaphore>(m, "Semaphore")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init<const SemaphoreStub&, const SemaphoreStub&>(), nb::arg("local_stub"), nb::arg("remote_stub"))
|
||||
.def("connection", &Semaphore::connection)
|
||||
.def("local_memory", &Semaphore::localMemory)
|
||||
.def("remote_memory", &Semaphore::remoteMemory);
|
||||
|
||||
def_shared_future<RegisteredMemory>(m, "RegisteredMemory");
|
||||
def_shared_future<std::shared_ptr<Connection>>(m, "shared_ptr_Connection");
|
||||
|
||||
nb::class_<Communicator>(m, "Communicator")
|
||||
.def(nb::init<std::shared_ptr<Bootstrap>, std::shared_ptr<Context>>(), nb::arg("bootstrap"),
|
||||
nb::arg("context") = nullptr)
|
||||
.def("bootstrap", &Communicator::bootstrap)
|
||||
.def("context", &Communicator::context)
|
||||
.def(
|
||||
"register_memory",
|
||||
[](Communicator* self, uintptr_t ptr, size_t size, TransportFlags transports) {
|
||||
return self->registerMemory((void*)ptr, size, transports);
|
||||
},
|
||||
nb::arg("ptr"), nb::arg("size"), nb::arg("transports"))
|
||||
.def("send_memory", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remote_rank"), nb::arg("tag") = 0)
|
||||
.def("recv_memory", &Communicator::recvMemory, nb::arg("remote_rank"), nb::arg("tag") = 0)
|
||||
.def("connect",
|
||||
static_cast<std::shared_future<std::shared_ptr<Connection>> (Communicator::*)(const EndpointConfig&, int,
|
||||
int)>(&Communicator::connect),
|
||||
nb::arg("local_config"), nb::arg("remote_rank"), nb::arg("tag") = 0)
|
||||
.def(
|
||||
"connect_on_setup",
|
||||
[](Communicator* self, int remoteRank, int tag, EndpointConfig localConfig) {
|
||||
return self->connect(std::move(localConfig), remoteRank, tag);
|
||||
},
|
||||
nb::arg("remote_rank"), nb::arg("tag"), nb::arg("local_config"))
|
||||
.def("send_memory_on_setup", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remote_rank"), nb::arg("tag"))
|
||||
.def("recv_memory_on_setup", &Communicator::recvMemory, nb::arg("remote_rank"), nb::arg("tag"))
|
||||
.def("build_semaphore", &Communicator::buildSemaphore, nb::arg("local_flag"), nb::arg("remote_rank"),
|
||||
nb::arg("tag") = 0)
|
||||
.def("remote_rank_of", &Communicator::remoteRankOf)
|
||||
.def("tag_of", &Communicator::tagOf)
|
||||
.def("setup", [](Communicator*) {});
|
||||
}
|
||||
|
||||
NB_MODULE(_mscclpp, m) {
|
||||
register_env(m);
|
||||
register_error(m);
|
||||
register_port_channel(m);
|
||||
register_memory_channel(m);
|
||||
register_fifo(m);
|
||||
register_semaphore(m);
|
||||
register_utils(m);
|
||||
register_core(m);
|
||||
register_numa(m);
|
||||
register_nvls(m);
|
||||
register_executor(m);
|
||||
register_npkit(m);
|
||||
register_gpu_utils(m);
|
||||
}
|
||||
@@ -1,28 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include <mscclpp/env.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_env(nb::module_& m) {
|
||||
nb::class_<Env>(m, "Env")
|
||||
.def_ro("debug", &Env::debug)
|
||||
.def_ro("debug_subsys", &Env::debugSubsys)
|
||||
.def_ro("debug_file", &Env::debugFile)
|
||||
.def_ro("hca_devices", &Env::hcaDevices)
|
||||
.def_ro("hostid", &Env::hostid)
|
||||
.def_ro("socket_family", &Env::socketFamily)
|
||||
.def_ro("socket_ifname", &Env::socketIfname)
|
||||
.def_ro("comm_id", &Env::commId)
|
||||
.def_ro("execution_plan_dir", &Env::executionPlanDir)
|
||||
.def_ro("npkit_dump_dir", &Env::npkitDumpDir)
|
||||
.def_ro("cuda_ipc_use_default_stream", &Env::cudaIpcUseDefaultStream);
|
||||
|
||||
m.def("env", &env);
|
||||
}
|
||||
@@ -1,51 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include <mscclpp/errors.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
#define REGISTER_EXCEPTION_TRANSLATOR(name_) \
|
||||
nb::register_exception_translator( \
|
||||
[](const std::exception_ptr &p, void *payload) { \
|
||||
try { \
|
||||
std::rethrow_exception(p); \
|
||||
} catch (const name_ &e) { \
|
||||
PyErr_SetObject(reinterpret_cast<PyObject *>(payload), \
|
||||
PyTuple_Pack(2, PyLong_FromLong(long(e.getErrorCode())), PyUnicode_FromString(e.what()))); \
|
||||
} \
|
||||
}, \
|
||||
m.attr(#name_).ptr());
|
||||
|
||||
void register_error(nb::module_ &m) {
|
||||
nb::enum_<ErrorCode>(m, "ErrorCode")
|
||||
.value("SystemError", ErrorCode::SystemError)
|
||||
.value("InternalError", ErrorCode::InternalError)
|
||||
.value("RemoteError", ErrorCode::RemoteError)
|
||||
.value("InvalidUsage", ErrorCode::InvalidUsage)
|
||||
.value("Timeout", ErrorCode::Timeout)
|
||||
.value("Aborted", ErrorCode::Aborted)
|
||||
.value("ExecutorError", ErrorCode::ExecutorError);
|
||||
|
||||
nb::exception<BaseError>(m, "BaseError");
|
||||
REGISTER_EXCEPTION_TRANSLATOR(BaseError);
|
||||
|
||||
nb::exception<Error>(m, "Error", m.attr("BaseError").ptr());
|
||||
REGISTER_EXCEPTION_TRANSLATOR(Error);
|
||||
|
||||
nb::exception<SysError>(m, "SysError", m.attr("BaseError").ptr());
|
||||
REGISTER_EXCEPTION_TRANSLATOR(SysError);
|
||||
|
||||
nb::exception<CudaError>(m, "CudaError", m.attr("BaseError").ptr());
|
||||
REGISTER_EXCEPTION_TRANSLATOR(CudaError);
|
||||
|
||||
nb::exception<CuError>(m, "CuError", m.attr("BaseError").ptr());
|
||||
REGISTER_EXCEPTION_TRANSLATOR(CuError);
|
||||
|
||||
nb::exception<IbError>(m, "IbError", m.attr("BaseError").ptr());
|
||||
REGISTER_EXCEPTION_TRANSLATOR(IbError);
|
||||
}
|
||||
@@ -1,43 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include <mscclpp/executor.hpp>
|
||||
#include <mscclpp/gpu.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_executor(nb::module_& m) {
|
||||
nb::enum_<DataType>(m, "DataType")
|
||||
.value("int32", DataType::INT32)
|
||||
.value("uint32", DataType::UINT32)
|
||||
.value("float16", DataType::FLOAT16)
|
||||
.value("float32", DataType::FLOAT32)
|
||||
.value("bfloat16", DataType::BFLOAT16);
|
||||
|
||||
nb::enum_<PacketType>(m, "PacketType").value("LL8", PacketType::LL8).value("LL16", PacketType::LL16);
|
||||
|
||||
nb::class_<ExecutionPlan>(m, "ExecutionPlan")
|
||||
.def(nb::init<const std::string&, int>(), nb::arg("planPath"), nb::arg("rank"))
|
||||
.def("name", &ExecutionPlan::name)
|
||||
.def("collective", &ExecutionPlan::collective)
|
||||
.def("min_message_size", &ExecutionPlan::minMessageSize)
|
||||
.def("max_message_size", &ExecutionPlan::maxMessageSize);
|
||||
|
||||
nb::class_<Executor>(m, "Executor")
|
||||
.def(nb::init<std::shared_ptr<Communicator>>(), nb::arg("comm"))
|
||||
.def(
|
||||
"execute",
|
||||
[](Executor* self, int rank, uintptr_t sendbuff, uintptr_t recvBuff, size_t sendBuffSize, size_t recvBuffSize,
|
||||
DataType dataType, const ExecutionPlan& plan, uintptr_t stream, PacketType packetType) {
|
||||
self->execute(rank, reinterpret_cast<void*>(sendbuff), reinterpret_cast<void*>(recvBuff), sendBuffSize,
|
||||
recvBuffSize, dataType, plan, (cudaStream_t)stream, packetType);
|
||||
},
|
||||
nb::arg("rank"), nb::arg("send_buff"), nb::arg("recv_buff"), nb::arg("send_buff_size"),
|
||||
nb::arg("recv_buff_size"), nb::arg("data_type"), nb::arg("plan"), nb::arg("stream"),
|
||||
nb::arg("packet_type") = PacketType::LL16);
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
|
||||
#include <mscclpp/fifo.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_fifo(nb::module_& m) {
|
||||
nb::class_<ProxyTrigger>(m, "ProxyTrigger")
|
||||
.def_prop_rw(
|
||||
"fst", [](const ProxyTrigger& self) { return self.fst; },
|
||||
[](ProxyTrigger& self, uint64_t v) { self.fst = v; })
|
||||
.def_prop_rw(
|
||||
"snd", [](const ProxyTrigger& self) { return self.snd; },
|
||||
[](ProxyTrigger& self, uint64_t v) { self.snd = v; });
|
||||
|
||||
nb::class_<FifoDeviceHandle>(m, "FifoDeviceHandle")
|
||||
.def_rw("triggers", &FifoDeviceHandle::triggers)
|
||||
.def_rw("tail", &FifoDeviceHandle::tail)
|
||||
.def_rw("head", &FifoDeviceHandle::head)
|
||||
.def_rw("size", &FifoDeviceHandle::size)
|
||||
.def_prop_ro("raw", [](const FifoDeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<Fifo>(m, "Fifo")
|
||||
.def(nb::init<int>(), nb::arg("size") = DEFAULT_FIFO_SIZE)
|
||||
.def("poll", &Fifo::poll)
|
||||
.def("pop", &Fifo::pop)
|
||||
.def("size", &Fifo::size)
|
||||
.def("device_handle", &Fifo::deviceHandle);
|
||||
}
|
||||
@@ -1,116 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <dlpack/dlpack.h>
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <mscclpp/gpu_data_types.hpp>
|
||||
#include <mscclpp/gpu_utils.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
constexpr int BYTE_BITS = 8;
|
||||
|
||||
static DLDeviceType getDeviceType() {
|
||||
#if defined(__HIP_PLATFORM_AMD__)
|
||||
return kDLROCM;
|
||||
#else
|
||||
return kDLCUDA;
|
||||
#endif
|
||||
}
|
||||
|
||||
static DLDataType getDlType(std::string type) {
|
||||
if (type == "torch.float32") {
|
||||
return DLDataType{kDLFloat, 32, 1};
|
||||
} else if (type == "torch.int32") {
|
||||
return DLDataType{kDLInt, 32, 1};
|
||||
} else if (type == "torch.uint32") {
|
||||
return DLDataType{kDLUInt, 32, 1};
|
||||
} else if (type == "torch.bfloat16") {
|
||||
return DLDataType{kDLBfloat, 16, 1};
|
||||
} else if (type == "torch.float16") {
|
||||
return DLDataType{kDLFloat, 16, 1};
|
||||
} else {
|
||||
throw Error("Unsupported type: " + type, ErrorCode::InvalidUsage);
|
||||
}
|
||||
}
|
||||
|
||||
static nb::capsule toDlpack(GpuBuffer<char> buffer, std::string dataType, std::vector<int64_t>& shape,
|
||||
std::vector<int64_t>& strides) {
|
||||
DLDataType dtype = getDlType(dataType);
|
||||
int64_t* tensorShape = shape.size() > 0 ? new int64_t[shape.size()] : new int64_t[1];
|
||||
int64_t* tensorStrides = strides.size() > 0 ? new int64_t[strides.size()] : nullptr;
|
||||
if (shape.size() == 0) {
|
||||
tensorShape[0] = (int64_t)(buffer.nelems() / ((dtype.bits * dtype.lanes + 7) / BYTE_BITS));
|
||||
} else {
|
||||
for (size_t i = 0; i < shape.size(); ++i) {
|
||||
tensorShape[i] = shape[i];
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < strides.size(); ++i) {
|
||||
tensorStrides[i] = strides[i];
|
||||
}
|
||||
|
||||
DLManagedTensor* dlManagedTensor = new DLManagedTensor();
|
||||
dlManagedTensor->dl_tensor.data = buffer.data();
|
||||
dlManagedTensor->dl_tensor.device.device_type = getDeviceType();
|
||||
dlManagedTensor->dl_tensor.device.device_id = buffer.deviceId();
|
||||
dlManagedTensor->dl_tensor.ndim = shape.size() == 0 ? 1 : shape.size();
|
||||
dlManagedTensor->dl_tensor.strides = tensorStrides;
|
||||
dlManagedTensor->dl_tensor.shape = tensorShape;
|
||||
dlManagedTensor->dl_tensor.byte_offset = 0;
|
||||
dlManagedTensor->dl_tensor.dtype = dtype;
|
||||
dlManagedTensor->manager_ctx = new GpuBuffer<char>(buffer);
|
||||
dlManagedTensor->deleter = [](DLManagedTensor* self) {
|
||||
delete static_cast<GpuBuffer<char>*>(self->manager_ctx);
|
||||
self->manager_ctx = nullptr;
|
||||
self->dl_tensor.data = nullptr;
|
||||
if (self->dl_tensor.shape != nullptr) {
|
||||
delete[] self->dl_tensor.shape;
|
||||
self->dl_tensor.shape = nullptr;
|
||||
if (self->dl_tensor.strides) {
|
||||
delete[] self->dl_tensor.strides;
|
||||
self->dl_tensor.strides = nullptr;
|
||||
}
|
||||
}
|
||||
delete self;
|
||||
};
|
||||
|
||||
PyObject* dlCapsule = PyCapsule_New(static_cast<void*>(dlManagedTensor), "dltensor", [](PyObject* capsule) {
|
||||
if (PyCapsule_IsValid(capsule, "used_dltensor")) {
|
||||
return;
|
||||
}
|
||||
if (!PyCapsule_IsValid(capsule, "dltensor")) {
|
||||
return;
|
||||
}
|
||||
DLManagedTensor* managedTensor = static_cast<DLManagedTensor*>(PyCapsule_GetPointer(capsule, "dltensor"));
|
||||
if (managedTensor == nullptr) {
|
||||
return;
|
||||
}
|
||||
if (managedTensor->deleter) {
|
||||
managedTensor->deleter(managedTensor);
|
||||
}
|
||||
});
|
||||
return nb::steal<nb::capsule>(dlCapsule);
|
||||
}
|
||||
|
||||
void register_gpu_utils(nb::module_& m) {
|
||||
m.def("is_nvls_supported", &isNvlsSupported);
|
||||
|
||||
nb::class_<GpuBuffer<char>>(m, "RawGpuBuffer")
|
||||
.def(nb::init<size_t>(), nb::arg("nelems"))
|
||||
.def("nelems", &GpuBuffer<char>::nelems)
|
||||
.def("bytes", &GpuBuffer<char>::bytes)
|
||||
.def("data", [](GpuBuffer<char>& self) { return reinterpret_cast<uintptr_t>(self.data()); })
|
||||
.def("device_id", &GpuBuffer<char>::deviceId)
|
||||
.def(
|
||||
"to_dlpack",
|
||||
[](GpuBuffer<char>& self, std::string dataType, std::vector<int64_t> shape, std::vector<int64_t> strides) {
|
||||
return toDlpack(self, dataType, shape, strides);
|
||||
},
|
||||
nb::arg("data_type"), nb::arg("shape") = std::vector<int64_t>(), nb::arg("strides") = std::vector<int64_t>());
|
||||
}
|
||||
@@ -26,6 +26,11 @@ class MemoryChannel:
|
||||
|
||||
_channel_counts = defaultdict(int)
|
||||
|
||||
@classmethod
|
||||
def reset(cls):
|
||||
"""Reset all channel counts for this channel type."""
|
||||
cls._channel_counts.clear()
|
||||
|
||||
def __init__(self, dst_rank: int, src_rank: int):
|
||||
"""Initialize a new MemoryChannel.
|
||||
|
||||
@@ -453,6 +458,11 @@ class PortChannel:
|
||||
|
||||
_channel_counts = defaultdict(int)
|
||||
|
||||
@classmethod
|
||||
def reset(cls):
|
||||
"""Reset all channel counts for this channel type."""
|
||||
cls._channel_counts.clear()
|
||||
|
||||
def __init__(self, dst_rank: int, src_rank: int):
|
||||
"""Initialize a new PortChannel.
|
||||
|
||||
@@ -741,6 +751,11 @@ class SwitchChannel:
|
||||
|
||||
_channel_counts = defaultdict(int)
|
||||
|
||||
@classmethod
|
||||
def reset(cls):
|
||||
"""Reset all channel counts for this channel type."""
|
||||
cls._channel_counts.clear()
|
||||
|
||||
def __init__(self, rank_list: List[int], buffer_type: BufferType):
|
||||
"""Initialize a new SwitchChannel.
|
||||
|
||||
|
||||
6
python/mscclpp/language/default_algos/__init__.py
Normal file
6
python/mscclpp/language/default_algos/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from mscclpp.language.default_algos.allreduce_2nodes import allreduce_2nodes
|
||||
|
||||
__all__ = ["allreduce_2nodes"]
|
||||
@@ -7,7 +7,7 @@ This implements a hierarchical AllReduce: intra-node allreduce followed by
|
||||
inter-node exchange and final intra-node allreduce.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from mscclpp.language.utils import AlgoSpec
|
||||
from mscclpp.language.channel import *
|
||||
from mscclpp.language.rank import *
|
||||
from mscclpp.language.general import *
|
||||
@@ -15,9 +15,7 @@ from mscclpp.language.program import *
|
||||
from mscclpp.language.collectives import *
|
||||
|
||||
|
||||
def allreduce_example(
|
||||
program_name, gpus_per_node, thread_block_group_size, num_threads_per_block, min_message_size, max_message_size
|
||||
):
|
||||
def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size: int) -> CollectiveProgram:
|
||||
"""
|
||||
Implements a multi-node AllReduce using a hierarchical approach:
|
||||
1. Intra-node allreduce
|
||||
@@ -26,24 +24,11 @@ def allreduce_example(
|
||||
"""
|
||||
# Configuration constants
|
||||
num_nodes = 2
|
||||
gpus_per_node = spec.nranks_per_node
|
||||
total_gpus = num_nodes * gpus_per_node
|
||||
chunks_per_loop = 1
|
||||
packets_per_gpu = 2 # Each GPU handles 2 data packets
|
||||
packets_per_gpu = 2
|
||||
|
||||
# Initialize collective operation
|
||||
collective = AllReduce(total_gpus, chunks_per_loop, True)
|
||||
|
||||
with CollectiveProgram(
|
||||
program_name,
|
||||
collective,
|
||||
total_gpus,
|
||||
protocol="LL",
|
||||
num_threads_per_block=num_threads_per_block,
|
||||
reuse_resources=False,
|
||||
use_double_scratch_buffer=True,
|
||||
min_message_size=min_message_size,
|
||||
max_message_size=max_message_size,
|
||||
):
|
||||
with CollectiveProgram.from_spec(spec) as prog:
|
||||
# Initialize communication channels and buffers
|
||||
intra_node_memory_channels = {}
|
||||
inter_node_port_channels = {}
|
||||
@@ -175,25 +160,4 @@ def allreduce_example(
|
||||
tb_group=thread_block_group,
|
||||
)
|
||||
|
||||
print(JSON())
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--name", type=str, help="name of the program")
|
||||
parser.add_argument("--gpus_per_node", type=int, help="number of gpus per node")
|
||||
parser.add_argument("--tbg_size", type=int, help="number of thread blocks in the thread block group")
|
||||
parser.add_argument("--num_threads_per_block", type=int, default=1024, help="number of threads per block")
|
||||
parser.add_argument("--min_message_size", type=int, default=0, help="minimum message size")
|
||||
parser.add_argument("--max_message_size", type=int, default=2 * 2**20, help="maximum message size")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
allreduce_example(
|
||||
args.name,
|
||||
args.gpus_per_node,
|
||||
args.tbg_size,
|
||||
args.num_threads_per_block,
|
||||
args.min_message_size,
|
||||
args.max_message_size,
|
||||
)
|
||||
return prog
|
||||
@@ -16,5 +16,4 @@ def JSON():
|
||||
str: A JSON string representation of the current MSCCL++ program,
|
||||
including all ranks, operations, channels, and configuration.
|
||||
"""
|
||||
get_program().post_process_operations()
|
||||
return get_program().to_json()
|
||||
|
||||
@@ -6,6 +6,7 @@ from mscclpp.language.internal.optimizer import *
|
||||
from mscclpp.language.internal.buffer_access import *
|
||||
from dataclasses import dataclass, field
|
||||
from collections import OrderedDict
|
||||
from typing import List
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -88,7 +89,7 @@ class ThreadBlock:
|
||||
@dataclass
|
||||
class Channel:
|
||||
channel_type: ChannelType
|
||||
channel_ids: list[int] = field(default_factory=list)
|
||||
channel_ids: List[int] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {"channel_type": self.channel_type.value, "channel_ids": self.channel_ids}
|
||||
@@ -96,7 +97,7 @@ class ThreadBlock:
|
||||
@dataclass
|
||||
class RemoteBuffer:
|
||||
access_channel_type: ChannelType
|
||||
remote_buffer_ids: list[int] = field(default_factory=list)
|
||||
remote_buffer_ids: List[int] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import List, Set
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
class SyncType(Enum):
|
||||
|
||||
@@ -3,8 +3,12 @@
|
||||
|
||||
from mscclpp.language.collectives import Collective
|
||||
from mscclpp.language.internal.globals import set_program
|
||||
from mscclpp.language.internal.types import BufferType, RemoteBuffer, ChannelType, ReplicationPolicy
|
||||
from mscclpp.language.internal.types import BufferType, RemoteBuffer, ChannelType
|
||||
from mscclpp.language.internal.gpu import Gpu
|
||||
from mscclpp.language.channel import *
|
||||
from mscclpp.language.rank import Semaphore
|
||||
from mscclpp.language.collectives import *
|
||||
from mscclpp.language.utils import AlgoSpec, ReplicationPolicy
|
||||
from typing import List
|
||||
import json
|
||||
|
||||
@@ -108,6 +112,55 @@ class CollectiveProgram:
|
||||
|
||||
self.loop_context = None
|
||||
|
||||
@classmethod
|
||||
def from_spec(cls, spec: AlgoSpec):
|
||||
"""Initialize a new CollectiveProgram from an algorithm specification.
|
||||
|
||||
This constructor provides an alternative way to create a CollectiveProgram
|
||||
using an AlgoSpec object, which contains the complete algorithm specification
|
||||
including collective instance, protocol parameters, and optimization settings.
|
||||
The collective operation is directly provided through the spec's collective attribute.
|
||||
|
||||
Args:
|
||||
spec (AlgoSpec): Algorithm specification containing all program parameters
|
||||
and configuration settings, including a Collective instance.
|
||||
|
||||
Raises:
|
||||
AssertionError: If protocol is not "Simple" or "LL".
|
||||
|
||||
Example:
|
||||
>>> from mscclpp.language.utils import AlgoSpec
|
||||
>>> from mscclpp.language.collectives import AllReduce
|
||||
>>> collective = AllReduce(num_ranks=4, chunk_factor=1, inplace=False)
|
||||
>>> spec = AlgoSpec(
|
||||
... name="my_allreduce",
|
||||
... collective=collective,
|
||||
... world_size=4,
|
||||
... instances=1,
|
||||
... protocol="Simple",
|
||||
... in_place=False
|
||||
... )
|
||||
>>> with CollectiveProgram.from_spec(spec) as prog:
|
||||
... # Define communication operations
|
||||
... pass
|
||||
"""
|
||||
return cls(
|
||||
spec.name,
|
||||
spec.collective,
|
||||
spec.world_size,
|
||||
instances=spec.instances,
|
||||
protocol=spec.protocol,
|
||||
instr_fusion=spec.instr_fusion,
|
||||
auto_sync=spec.auto_sync,
|
||||
replication_policy=spec.replication_policy,
|
||||
reuse_resources=spec.reuse_resources,
|
||||
num_threads_per_block=spec.num_threads_per_block,
|
||||
use_double_scratch_buffer=spec.use_double_scratch_buffer,
|
||||
buffer_alignment=spec.buffer_alignment,
|
||||
min_message_size=spec.min_message_size,
|
||||
max_message_size=spec.max_message_size,
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
"""Enter the program context and set this as the active program.
|
||||
|
||||
@@ -115,6 +168,7 @@ class CollectiveProgram:
|
||||
this program as the active program in the global context.
|
||||
"""
|
||||
set_program(self)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
"""Exit the program context and clear the active program.
|
||||
@@ -122,6 +176,10 @@ class CollectiveProgram:
|
||||
This method is called when exiting the 'with' statement and removes
|
||||
this program from the global context.
|
||||
"""
|
||||
MemoryChannel.reset()
|
||||
PortChannel.reset()
|
||||
SwitchChannel.reset()
|
||||
Semaphore.reset()
|
||||
set_program(None)
|
||||
|
||||
def add_channel(self, channel):
|
||||
@@ -175,7 +233,8 @@ class CollectiveProgram:
|
||||
raise RuntimeError("Nested Pipelines are not Supported.")
|
||||
self.loop_context = loop_context
|
||||
|
||||
def to_json(self):
|
||||
def to_json(self, indent=2, **kwargs):
|
||||
self.post_process_operations()
|
||||
json_obj = {
|
||||
"name": self.name,
|
||||
"collective": self.collective.name,
|
||||
@@ -190,4 +249,4 @@ class CollectiveProgram:
|
||||
"max_message_size": self.max_message_size,
|
||||
}
|
||||
|
||||
return json.dumps(json_obj, indent=2)
|
||||
return json.dumps(json_obj, indent=indent, **kwargs)
|
||||
|
||||
@@ -367,6 +367,11 @@ class Semaphore:
|
||||
|
||||
_semaphore_counts = defaultdict(int)
|
||||
|
||||
@classmethod
|
||||
def reset(cls):
|
||||
"""Reset all semaphore counts."""
|
||||
cls._semaphore_counts.clear()
|
||||
|
||||
def __init__(self, rank: int, initial_value: int):
|
||||
"""Initialize a new Semaphore.
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ def allreduce_example(name, gpu_size, num_threads_per_block, min_message_size, m
|
||||
input_buffer = rank.get_input_buffer()
|
||||
for peer in range(gpu_size):
|
||||
if peer != gpu:
|
||||
channels[(peer, gpu)].put_packet(
|
||||
channels[(peer, gpu)].put_packets(
|
||||
scratch_buffer[peer][gpu : gpu + 1], input_buffer[peer : peer + 1], 0
|
||||
)
|
||||
|
||||
@@ -55,7 +55,7 @@ def allreduce_example(name, gpu_size, num_threads_per_block, min_message_size, m
|
||||
rank.reduce(input_buffer[gpu : gpu + 1], chunks, 0, packet=True)
|
||||
for peer in range(gpu_size):
|
||||
if peer != gpu:
|
||||
channels[(peer, gpu)].put_packet(
|
||||
channels[(peer, gpu)].put_packets(
|
||||
scratch_buffer[peer][gpu_size + gpu : gpu_size + gpu + 1], input_buffer[gpu : gpu + 1], 0
|
||||
)
|
||||
|
||||
@@ -65,7 +65,7 @@ def allreduce_example(name, gpu_size, num_threads_per_block, min_message_size, m
|
||||
input_buffer = rank.get_input_buffer()
|
||||
for peer in range(gpu_size):
|
||||
if peer != gpu:
|
||||
rank.unpack_packet(
|
||||
rank.unpack_packets(
|
||||
input_buffer[peer : peer + 1], scratch_buffer[gpu][gpu_size + peer : gpu_size + peer + 1], 0
|
||||
)
|
||||
|
||||
|
||||
35
python/mscclpp/language/utils.py
Normal file
35
python/mscclpp/language/utils.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass, field
|
||||
from mscclpp.language.collectives import Collective
|
||||
|
||||
|
||||
class ReplicationPolicy(Enum):
|
||||
interleaved = "interleaved"
|
||||
none = "none"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AlgoSpec:
|
||||
name: str
|
||||
collective: Collective
|
||||
nranks_per_node: int
|
||||
world_size: int
|
||||
in_place: bool
|
||||
instances: int
|
||||
protocol: str
|
||||
instr_fusion: bool = True
|
||||
auto_sync: bool = True
|
||||
replication_policy: ReplicationPolicy = ReplicationPolicy.interleaved
|
||||
reuse_resources: bool = False
|
||||
num_threads_per_block: int = 1024
|
||||
use_double_scratch_buffer: bool = False
|
||||
buffer_alignment: int = 16
|
||||
min_message_size: int = 0
|
||||
max_message_size: int = 2**64 - 1
|
||||
tags: dict = field(default_factory=dict)
|
||||
@@ -1,48 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include <mscclpp/memory_channel.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_memory_channel(nb::module_& m) {
|
||||
nb::class_<BaseMemoryChannel>(m, "BaseMemoryChannel")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init<std::shared_ptr<MemoryDevice2DeviceSemaphore>>(), nb::arg("semaphore"))
|
||||
.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
|
||||
.def("device_handle", &BaseMemoryChannel::deviceHandle);
|
||||
|
||||
nb::class_<BaseMemoryChannel::DeviceHandle>(m, "BaseMemoryChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("semaphore_", &BaseMemoryChannel::DeviceHandle::semaphore_)
|
||||
.def_prop_ro("raw", [](const BaseMemoryChannel::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<MemoryChannel>(m, "MemoryChannel")
|
||||
.def(nb::init<>())
|
||||
.def("__init__",
|
||||
[](MemoryChannel* memoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
|
||||
RegisteredMemory dst, RegisteredMemory src) { new (memoryChannel) MemoryChannel(semaphore, dst, src); })
|
||||
.def("__init__",
|
||||
[](MemoryChannel* memoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
|
||||
RegisteredMemory dst, RegisteredMemory src, uintptr_t packet_buffer) {
|
||||
new (memoryChannel) MemoryChannel(semaphore, dst, src, reinterpret_cast<void*>(packet_buffer));
|
||||
})
|
||||
.def("device_handle", &MemoryChannel::deviceHandle);
|
||||
|
||||
nb::class_<MemoryChannel::DeviceHandle>(m, "MemoryChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("semaphore_", &MemoryChannel::DeviceHandle::semaphore_)
|
||||
.def_rw("dst_", &MemoryChannel::DeviceHandle::dst_)
|
||||
.def_rw("src_", &MemoryChannel::DeviceHandle::src_)
|
||||
.def_rw("packetBuffer_", &MemoryChannel::DeviceHandle::packetBuffer_)
|
||||
.def_prop_ro("raw", [](const MemoryChannel::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
};
|
||||
@@ -1,16 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include <mscclpp/npkit/npkit.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
|
||||
void register_npkit(nb::module_ &m) {
|
||||
nb::module_ sub_m = m.def_submodule("npkit", "NPKit functions");
|
||||
sub_m.def("init", &NpKit::Init);
|
||||
sub_m.def("dump", &NpKit::Dump);
|
||||
sub_m.def("shutdown", &NpKit::Shutdown);
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
#include <nanobind/nanobind.h>
|
||||
namespace nb = nanobind;
|
||||
|
||||
namespace mscclpp {
|
||||
int getDeviceNumaNode(int cudaDev);
|
||||
void numaBind(int node);
|
||||
}; // namespace mscclpp
|
||||
|
||||
void register_numa(nb::module_ &m) {
|
||||
nb::module_ sub_m = m.def_submodule("numa", "numa functions");
|
||||
sub_m.def("get_device_numa_node", &mscclpp::getDeviceNumaNode);
|
||||
sub_m.def("numa_bind", &mscclpp::numaBind);
|
||||
}
|
||||
@@ -1,65 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include <mscclpp/port_channel.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_port_channel(nb::module_& m) {
|
||||
nb::class_<BaseProxyService>(m, "BaseProxyService")
|
||||
.def("start_proxy", &BaseProxyService::startProxy)
|
||||
.def("stop_proxy", &BaseProxyService::stopProxy);
|
||||
|
||||
nb::class_<ProxyService, BaseProxyService>(m, "ProxyService")
|
||||
.def(nb::init<int>(), nb::arg("fifo_size") = DEFAULT_FIFO_SIZE)
|
||||
.def("start_proxy", &ProxyService::startProxy)
|
||||
.def("stop_proxy", &ProxyService::stopProxy)
|
||||
.def("build_and_add_semaphore", &ProxyService::buildAndAddSemaphore, nb::arg("comm"), nb::arg("connection"))
|
||||
.def("add_semaphore", static_cast<SemaphoreId (ProxyService::*)(const Semaphore&)>(&ProxyService::addSemaphore),
|
||||
nb::arg("semaphore"))
|
||||
.def("add_semaphore",
|
||||
static_cast<SemaphoreId (ProxyService::*)(std::shared_ptr<Host2DeviceSemaphore>)>(
|
||||
&ProxyService::addSemaphore),
|
||||
nb::arg("semaphore"))
|
||||
.def("add_memory", &ProxyService::addMemory, nb::arg("memory"))
|
||||
.def("semaphore", &ProxyService::semaphore, nb::arg("id"))
|
||||
.def("base_port_channel", &ProxyService::basePortChannel, nb::arg("id"))
|
||||
.def("port_channel", &ProxyService::portChannel, nb::arg("id"), nb::arg("dst"), nb::arg("src"));
|
||||
|
||||
nb::class_<BasePortChannel>(m, "BasePortChannel")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init<SemaphoreId, std::shared_ptr<Host2DeviceSemaphore>, std::shared_ptr<Proxy>>(),
|
||||
nb::arg("semaphore_id"), nb::arg("semaphore"), nb::arg("proxy"))
|
||||
.def("device_handle", &BasePortChannel::deviceHandle);
|
||||
|
||||
nb::class_<BasePortChannel::DeviceHandle>(m, "BasePortChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("semaphore_id_", &BasePortChannel::DeviceHandle::semaphoreId_)
|
||||
.def_rw("semaphore_", &BasePortChannel::DeviceHandle::semaphore_)
|
||||
.def_rw("fifo_", &BasePortChannel::DeviceHandle::fifo_)
|
||||
.def_prop_ro("raw", [](const BasePortChannel::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<PortChannel>(m, "PortChannel")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init<SemaphoreId, std::shared_ptr<Host2DeviceSemaphore>, std::shared_ptr<Proxy>, MemoryId, MemoryId>(),
|
||||
nb::arg("semaphore_id"), nb::arg("semaphore"), nb::arg("proxy"), nb::arg("dst"), nb::arg("src"))
|
||||
.def("device_handle", &PortChannel::deviceHandle);
|
||||
|
||||
nb::class_<PortChannel::DeviceHandle>(m, "PortChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("semaphore_id_", &PortChannel::DeviceHandle::semaphoreId_)
|
||||
.def_rw("semaphore_", &PortChannel::DeviceHandle::semaphore_)
|
||||
.def_rw("fifo_", &PortChannel::DeviceHandle::fifo_)
|
||||
.def_rw("src_", &PortChannel::DeviceHandle::src_)
|
||||
.def_rw("dst_", &PortChannel::DeviceHandle::dst_)
|
||||
.def_prop_ro("raw", [](const PortChannel::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
};
|
||||
@@ -1,52 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
|
||||
#include <mscclpp/semaphore.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_semaphore(nb::module_& m) {
|
||||
nb::class_<Host2DeviceSemaphore> host2DeviceSemaphore(m, "Host2DeviceSemaphore");
|
||||
host2DeviceSemaphore.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
|
||||
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def("connection", &Host2DeviceSemaphore::connection)
|
||||
.def("signal", &Host2DeviceSemaphore::signal)
|
||||
.def("device_handle", &Host2DeviceSemaphore::deviceHandle);
|
||||
|
||||
nb::class_<Host2DeviceSemaphore::DeviceHandle>(host2DeviceSemaphore, "DeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("inbound_token", &Host2DeviceSemaphore::DeviceHandle::inboundToken)
|
||||
.def_rw("expected_inbound_token", &Host2DeviceSemaphore::DeviceHandle::expectedInboundToken)
|
||||
.def_prop_ro("raw", [](const Host2DeviceSemaphore::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<Host2HostSemaphore>(m, "Host2HostSemaphore")
|
||||
.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
|
||||
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def("connection", &Host2HostSemaphore::connection)
|
||||
.def("signal", &Host2HostSemaphore::signal)
|
||||
.def("poll", &Host2HostSemaphore::poll)
|
||||
.def("wait", &Host2HostSemaphore::wait, nb::call_guard<nb::gil_scoped_release>(),
|
||||
nb::arg("max_spin_count") = 10000000);
|
||||
|
||||
nb::class_<MemoryDevice2DeviceSemaphore> memoryDevice2DeviceSemaphore(m, "MemoryDevice2DeviceSemaphore");
|
||||
memoryDevice2DeviceSemaphore.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
|
||||
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def("connection", &MemoryDevice2DeviceSemaphore::connection)
|
||||
.def("device_handle", &MemoryDevice2DeviceSemaphore::deviceHandle);
|
||||
|
||||
nb::class_<MemoryDevice2DeviceSemaphore::DeviceHandle>(memoryDevice2DeviceSemaphore, "DeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("inbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::inboundToken)
|
||||
.def_rw("outbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::outboundToken)
|
||||
.def_rw("remote_inbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::remoteInboundToken)
|
||||
.def_rw("expected_inbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::expectedInboundToken)
|
||||
.def_prop_ro("raw", [](const MemoryDevice2DeviceSemaphore::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/operators.h>
|
||||
#include <nanobind/stl/array.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/switch_channel.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_nvls(nb::module_& m) {
|
||||
nb::class_<SwitchChannel>(m, "SwitchChannel")
|
||||
.def("get_device_ptr", [](SwitchChannel* self) { return (uintptr_t)self->getDevicePtr(); })
|
||||
.def("device_handle", &SwitchChannel::deviceHandle);
|
||||
|
||||
nb::class_<SwitchChannel::DeviceHandle>(m, "DeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("device_ptr", &SwitchChannel::DeviceHandle::devicePtr)
|
||||
.def_rw("mc_ptr", &SwitchChannel::DeviceHandle::mcPtr)
|
||||
.def_rw("size", &SwitchChannel::DeviceHandle::bufferSize)
|
||||
.def_prop_ro("raw", [](const SwitchChannel::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<NvlsConnection>(m, "NvlsConnection")
|
||||
.def("bind_allocated_memory", &NvlsConnection::bindAllocatedMemory, nb::arg("device_ptr"), nb::arg("size"))
|
||||
.def("get_multicast_min_granularity", &NvlsConnection::getMultiCastMinGranularity);
|
||||
|
||||
m.def("connect_nvls_collective", &connectNvlsCollective, nb::arg("communicator"), nb::arg("all_ranks"),
|
||||
nb::arg("buffer_size"));
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include <mscclpp/utils.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_utils(nb::module_& m) { m.def("get_host_name", &getHostName, nb::arg("maxlen"), nb::arg("delim")); }
|
||||
Reference in New Issue
Block a user