mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-03-24 00:57:47 +00:00
330 lines
15 KiB
C++
330 lines
15 KiB
C++
// Copyright (c) Microsoft Corporation.
|
|
// Licensed under the MIT license.
|
|
|
|
#include <nanobind/nanobind.h>
#include <nanobind/operators.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>

#include <mscclpp/core.hpp>

#include <cstdint>
#include <cstring>
#include <future>
#include <new>
#include <sstream>
#include <stdexcept>
|
|
|
|
namespace nb = nanobind;
|
|
using namespace mscclpp;
|
|
|
|
extern void register_env(nb::module_& m);
|
|
extern void register_error(nb::module_& m);
|
|
extern void register_port_channel(nb::module_& m);
|
|
extern void register_memory_channel(nb::module_& m);
|
|
extern void register_fifo(nb::module_& m);
|
|
extern void register_semaphore(nb::module_& m);
|
|
extern void register_utils(nb::module_& m);
|
|
extern void register_numa(nb::module_& m);
|
|
extern void register_nvls(nb::module_& m);
|
|
extern void register_executor(nb::module_& m);
|
|
extern void register_npkit(nb::module_& m);
|
|
extern void register_gpu_utils(nb::module_& m);
|
|
extern void register_algorithm(nb::module_& m);
|
|
|
|
// ext
|
|
extern void register_algorithm_collection_builder(nb::module_& m);
|
|
|
|
template <typename T>
|
|
void def_shared_future(nb::handle& m, const std::string& typestr) {
|
|
std::string pyclass_name = std::string("CppSharedFuture_") + typestr;
|
|
nb::class_<std::shared_future<T>>(m, pyclass_name.c_str()).def("get", &std::shared_future<T>::get);
|
|
}
|
|
|
|
/// Register the core mscclpp types on module `m`: data types, bootstrap, unique IDs,
/// transports, devices, endpoint configuration, registered memory, endpoints, connections,
/// contexts, semaphores, the shared-future wrappers and the communicator.
///
/// Ordering note: nanobind converts default argument values (`nb::arg(...) = value`) into
/// Python objects at the time `.def()` runs. Types used as defaults are therefore registered
/// before the bindings that use them:
///   - CppIbMode before CppEndpointConfigIb.
///   - CppEndpointConfigIb before CppEndpointConfig.
void register_core(nb::module_& m) {
  m.def("version", &version);

  nb::enum_<DataType>(m, "CppDataType")
      .value("int32", DataType::INT32)
      .value("uint32", DataType::UINT32)
      .value("float16", DataType::FLOAT16)
      .value("float32", DataType::FLOAT32)
      .value("bfloat16", DataType::BFLOAT16)
      .value("float8_e4m3", DataType::FLOAT8_E4M3)
      .value("float8_e5m2", DataType::FLOAT8_E5M2)
      .value("uint8", DataType::UINT8);

  nb::class_<Bootstrap>(m, "CppBootstrap")
      .def("get_rank", &Bootstrap::getRank)
      .def("get_n_ranks", &Bootstrap::getNranks)
      .def("get_n_ranks_per_node", &Bootstrap::getNranksPerNode)
      // Raw-pointer variants: Python passes buffer addresses as integers.
      .def(
          "send",
          [](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) {
            self->send(reinterpret_cast<void*>(ptr), size, peer, tag);
          },
          nb::arg("data"), nb::arg("size"), nb::arg("peer"), nb::arg("tag"))
      .def(
          "recv",
          [](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) {
            self->recv(reinterpret_cast<void*>(ptr), size, peer, tag);
          },
          nb::arg("data"), nb::arg("size"), nb::arg("peer"), nb::arg("tag"))
      .def("all_gather", &Bootstrap::allGather, nb::arg("allData"), nb::arg("size"))
      .def("barrier", &Bootstrap::barrier)
      .def("send", static_cast<void (Bootstrap::*)(const std::vector<char>&, int, int)>(&Bootstrap::send),
           nb::arg("data"), nb::arg("peer"), nb::arg("tag"))
      // nanobind's STL caster converts the Python argument into a temporary std::vector, so
      // anything Bootstrap::recv writes into `data` never reaches the caller's object. The
      // received payload is therefore returned explicitly as bytes.
      .def(
          "recv",
          [](Bootstrap* self, std::vector<char> data, int peer, int tag) {
            self->recv(data, peer, tag);
            return nb::bytes(data.data(), data.size());
          },
          nb::arg("data"), nb::arg("peer"), nb::arg("tag"));

  nb::class_<UniqueId>(m, "CppUniqueId")
      .def(nb::init<>())
      // Pickling support: the state is the raw UniqueIdBytes-long identifier.
      .def("__setstate__",
           [](UniqueId& self, nb::bytes b) {
             if (nb::len(b) != UniqueIdBytes) throw std::runtime_error("Invalid UniqueId byte size");
             // During unpickling nanobind passes uninitialized storage to __setstate__, so the
             // object's lifetime is started with placement new before its bytes are written.
             new (&self) UniqueId();
             std::memcpy(self.data(), b.c_str(), UniqueIdBytes);
           })
      .def("__getstate__", [](const UniqueId& self) {
        return nb::bytes(reinterpret_cast<const char*>(self.data()), UniqueIdBytes);
      });

  nb::class_<TcpBootstrap, Bootstrap>(m, "CppTcpBootstrap")
      .def(nb::init<int, int>(), "Do not use this constructor. Use create instead.")
      .def_static(
          "create", [](int rank, int nRanks) { return std::make_shared<TcpBootstrap>(rank, nRanks); }, nb::arg("rank"),
          nb::arg("nRanks"))
      .def_static("create_unique_id", &TcpBootstrap::createUniqueId)
      .def("get_unique_id", &TcpBootstrap::getUniqueId)
      // initialize() waits for the other ranks, so the GIL is released for the duration.
      .def("initialize", static_cast<void (TcpBootstrap::*)(UniqueId, int64_t)>(&TcpBootstrap::initialize),
           nb::call_guard<nb::gil_scoped_release>(), nb::arg("unique_id"), nb::arg("timeout_sec") = 30)
      .def("initialize", static_cast<void (TcpBootstrap::*)(const std::string&, int64_t)>(&TcpBootstrap::initialize),
           nb::call_guard<nb::gil_scoped_release>(), nb::arg("if_ip_port_trio"), nb::arg("timeout_sec") = 30);

  nb::enum_<Transport>(m, "CppTransport")
      .value("Unknown", Transport::Unknown)
      .value("CudaIpc", Transport::CudaIpc)
      .value("IB0", Transport::IB0)
      .value("IB1", Transport::IB1)
      .value("IB2", Transport::IB2)
      .value("IB3", Transport::IB3)
      .value("IB4", Transport::IB4)
      .value("IB5", Transport::IB5)
      .value("IB6", Transport::IB6)
      .value("IB7", Transport::IB7)
      .value("NumTransports", Transport::NumTransports);

  nb::class_<TransportFlags>(m, "CppTransportFlags")
      .def(nb::init<>())
      .def(nb::init_implicit<Transport>(), nb::arg("transport"))
      .def("has", &TransportFlags::has, nb::arg("transport"))
      .def("none", &TransportFlags::none)
      .def("any", &TransportFlags::any)
      .def("all", &TransportFlags::all)
      .def("count", &TransportFlags::count)
      .def(nb::self | nb::self)
      .def(nb::self | Transport())
      .def(nb::self & nb::self)
      .def(nb::self & Transport())
      .def(nb::self ^ nb::self)
      .def(nb::self ^ Transport())
      .def(
          "__ior__", [](TransportFlags& lhs, const TransportFlags& rhs) { return lhs |= rhs; }, nb::is_operator())
      .def(
          "__iand__", [](TransportFlags& lhs, const TransportFlags& rhs) { return lhs &= rhs; }, nb::is_operator())
      .def(
          "__ixor__", [](TransportFlags& lhs, const TransportFlags& rhs) { return lhs ^= rhs; }, nb::is_operator())
      .def(~nb::self)
      .def(nb::self == nb::self)
      .def(nb::self != nb::self);

  nb::enum_<DeviceType>(m, "CppDeviceType")
      .value("Unknown", DeviceType::Unknown)
      .value("CPU", DeviceType::CPU)
      .value("GPU", DeviceType::GPU);

  nb::class_<Device>(m, "CppDevice")
      .def(nb::init<>())
      .def(nb::init_implicit<DeviceType>(), nb::arg("type"))
      .def(nb::init<DeviceType, int>(), nb::arg("type"), nb::arg("id") = -1)
      .def_rw("type", &Device::type)
      .def_rw("id", &Device::id)
      .def("__str__", [](const Device& self) {
        std::stringstream ss;
        ss << self;
        return ss.str();
      });

  // Must precede CppEndpointConfigIb: its constructor uses Mode::Default as a default argument.
  nb::enum_<EndpointConfig::Ib::Mode>(m, "CppIbMode")
      .value("Default", EndpointConfig::Ib::Mode::Default)
      .value("Host", EndpointConfig::Ib::Mode::Host)
      .value("HostNoAtomic", EndpointConfig::Ib::Mode::HostNoAtomic);

  nb::class_<EndpointConfig::Ib>(m, "CppEndpointConfigIb")
      .def(nb::init<>())
      .def(nb::init<int, int, int, int, int, int, int, int, EndpointConfig::Ib::Mode>(), nb::arg("device_index") = -1,
           nb::arg("port") = EndpointConfig::Ib::DefaultPort,
           nb::arg("gid_index") = EndpointConfig::Ib::DefaultGidIndex,
           nb::arg("max_cq_size") = EndpointConfig::Ib::DefaultMaxCqSize,
           nb::arg("max_cq_poll_num") = EndpointConfig::Ib::DefaultMaxCqPollNum,
           nb::arg("max_send_wr") = EndpointConfig::Ib::DefaultMaxSendWr,
           nb::arg("max_recv_wr") = EndpointConfig::Ib::DefaultMaxRecvWr,
           nb::arg("max_wr_per_send") = EndpointConfig::Ib::DefaultMaxWrPerSend,
           nb::arg("mode") = EndpointConfig::Ib::Mode::Default)
      .def_rw("device_index", &EndpointConfig::Ib::deviceIndex)
      .def_rw("port", &EndpointConfig::Ib::port)
      .def_rw("gid_index", &EndpointConfig::Ib::gidIndex)
      .def_rw("max_cq_size", &EndpointConfig::Ib::maxCqSize)
      .def_rw("max_cq_poll_num", &EndpointConfig::Ib::maxCqPollNum)
      .def_rw("max_send_wr", &EndpointConfig::Ib::maxSendWr)
      .def_rw("max_recv_wr", &EndpointConfig::Ib::maxRecvWr)
      .def_rw("max_wr_per_send", &EndpointConfig::Ib::maxWrPerSend)
      .def_rw("mode", &EndpointConfig::Ib::mode);

  nb::class_<RegisteredMemory>(m, "CppRegisteredMemory")
      .def(nb::init<>())
      // The buffer address is handed to Python as an integer.
      .def("data", [](RegisteredMemory& self) { return reinterpret_cast<uintptr_t>(self.data()); })
      .def("size", &RegisteredMemory::size)
      .def("transports", &RegisteredMemory::transports)
      .def("serialize", &RegisteredMemory::serialize)
      .def_static("deserialize", &RegisteredMemory::deserialize, nb::arg("data"));

  nb::class_<Endpoint>(m, "CppEndpoint")
      .def("config", &Endpoint::config)
      .def("transport", &Endpoint::transport)
      .def("device", &Endpoint::device)
      .def("max_write_queue_size", &Endpoint::maxWriteQueueSize)
      .def("serialize", &Endpoint::serialize)
      .def_static("deserialize", &Endpoint::deserialize, nb::arg("data"));

  nb::class_<Connection>(m, "CppConnection")
      .def("write", &Connection::write, nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("srcOffset"),
           nb::arg("size"))
      .def(
          "update_and_sync",
          [](Connection* self, RegisteredMemory dst, uint64_t dstOffset, uintptr_t src, uint64_t newValue) {
            // `src` is the integer address of a uint64_t supplied from Python.
            self->updateAndSync(dst, dstOffset, reinterpret_cast<uint64_t*>(src), newValue);
          },
          nb::arg("dst"), nb::arg("dst_offset"), nb::arg("src"), nb::arg("new_value"))
      // Default timeout is 3e7 microseconds (30 s). The GIL is released while flushing.
      .def("flush", &Connection::flush, nb::call_guard<nb::gil_scoped_release>(),
           nb::arg("timeout_usec") = static_cast<int64_t>(3e7))
      .def("transport", &Connection::transport)
      .def("remote_transport", &Connection::remoteTransport)
      .def("context", &Connection::context)
      .def("local_device", &Connection::localDevice)
      .def("get_max_write_queue_size", &Connection::getMaxWriteQueueSize);

  nb::class_<EndpointConfig>(m, "CppEndpointConfig")
      .def(nb::init<>())
      .def(nb::init_implicit<Transport>(), nb::arg("transport"))
      .def(nb::init<Transport, Device, int, EndpointConfig::Ib>(), nb::arg("transport"), nb::arg("device"),
           nb::arg("max_write_queue_size") = -1, nb::arg("ib") = EndpointConfig::Ib{})
      .def_rw("transport", &EndpointConfig::transport)
      .def_rw("device", &EndpointConfig::device)
      .def_rw("ib", &EndpointConfig::ib)
      // Flattened accessors forwarding to the nested `ib` sub-config.
      .def_prop_rw(
          "ib_device_index", [](EndpointConfig& self) { return self.ib.deviceIndex; },
          [](EndpointConfig& self, int v) { self.ib.deviceIndex = v; })
      .def_prop_rw(
          "ib_port", [](EndpointConfig& self) { return self.ib.port; },
          [](EndpointConfig& self, int v) { self.ib.port = v; })
      .def_prop_rw(
          "ib_gid_index", [](EndpointConfig& self) { return self.ib.gidIndex; },
          [](EndpointConfig& self, int v) { self.ib.gidIndex = v; })
      .def_prop_rw(
          "ib_max_cq_size", [](EndpointConfig& self) { return self.ib.maxCqSize; },
          [](EndpointConfig& self, int v) { self.ib.maxCqSize = v; })
      .def_prop_rw(
          "ib_max_cq_poll_num", [](EndpointConfig& self) { return self.ib.maxCqPollNum; },
          [](EndpointConfig& self, int v) { self.ib.maxCqPollNum = v; })
      .def_prop_rw(
          "ib_max_send_wr", [](EndpointConfig& self) { return self.ib.maxSendWr; },
          [](EndpointConfig& self, int v) { self.ib.maxSendWr = v; })
      .def_prop_rw(
          "ib_max_recv_wr", [](EndpointConfig& self) { return self.ib.maxRecvWr; },
          [](EndpointConfig& self, int v) { self.ib.maxRecvWr = v; })
      .def_prop_rw(
          "ib_max_wr_per_send", [](EndpointConfig& self) { return self.ib.maxWrPerSend; },
          [](EndpointConfig& self, int v) { self.ib.maxWrPerSend = v; })
      .def_prop_rw(
          "ib_mode", [](EndpointConfig& self) { return self.ib.mode; },
          [](EndpointConfig& self, EndpointConfig::Ib::Mode v) { self.ib.mode = v; })
      .def_rw("max_write_queue_size", &EndpointConfig::maxWriteQueueSize);

  nb::class_<Context>(m, "CppContext")
      .def_static("create", &Context::create)
      .def(
          "register_memory",
          [](Context* self, uintptr_t ptr, size_t size, TransportFlags transports) {
            return self->registerMemory(reinterpret_cast<void*>(ptr), size, transports);
          },
          nb::arg("ptr"), nb::arg("size"), nb::arg("transports"))
      .def("create_endpoint", &Context::createEndpoint, nb::arg("config"))
      .def("connect", &Context::connect, nb::arg("local_endpoint"), nb::arg("remote_endpoint"));

  nb::class_<SemaphoreStub>(m, "CppSemaphoreStub")
      .def(nb::init<const Connection&>(), nb::arg("connection"))
      .def("memory", &SemaphoreStub::memory)
      .def("serialize", &SemaphoreStub::serialize)
      .def_static("deserialize", &SemaphoreStub::deserialize, nb::arg("data"));

  nb::class_<Semaphore>(m, "CppSemaphore")
      .def(nb::init<>())
      .def(nb::init<const SemaphoreStub&, const SemaphoreStub&>(), nb::arg("local_stub"), nb::arg("remote_stub"))
      .def("connection", &Semaphore::connection)
      .def("local_memory", &Semaphore::localMemory)
      .def("remote_memory", &Semaphore::remoteMemory);

  // Future wrappers for the values returned by the asynchronous Communicator methods below.
  def_shared_future<RegisteredMemory>(m, "RegisteredMemory");
  def_shared_future<Connection>(m, "Connection");
  def_shared_future<Semaphore>(m, "Semaphore");

  nb::class_<Communicator>(m, "CppCommunicator")
      .def(nb::init<std::shared_ptr<Bootstrap>, std::shared_ptr<Context>>(), nb::arg("bootstrap"),
           nb::arg("context") = nullptr)
      .def("bootstrap", &Communicator::bootstrap)
      .def("context", &Communicator::context)
      .def(
          "register_memory",
          [](Communicator* self, uintptr_t ptr, size_t size, TransportFlags transports) {
            return self->registerMemory(reinterpret_cast<void*>(ptr), size, transports);
          },
          nb::arg("ptr"), nb::arg("size"), nb::arg("transports"))
      .def("send_memory", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remote_rank"), nb::arg("tag") = 0)
      .def("recv_memory", &Communicator::recvMemory, nb::arg("remote_rank"), nb::arg("tag") = 0)
      .def("connect",
           static_cast<std::shared_future<Connection> (Communicator::*)(const EndpointConfig&, int, int)>(
               &Communicator::connect),
           nb::arg("local_config"), nb::arg("remote_rank"), nb::arg("tag") = 0)
      // The *_on_setup names forward to the same C++ calls as connect / send_memory /
      // recv_memory, but take a different argument order and have no default tag.
      // NOTE(review): presumably kept for older Python callers; confirm before removing.
      .def(
          "connect_on_setup",
          [](Communicator* self, int remoteRank, int tag, EndpointConfig localConfig) {
            return self->connect(std::move(localConfig), remoteRank, tag);
          },
          nb::arg("remote_rank"), nb::arg("tag"), nb::arg("local_config"))
      .def("send_memory_on_setup", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remote_rank"), nb::arg("tag"))
      .def("recv_memory_on_setup", &Communicator::recvMemory, nb::arg("remote_rank"), nb::arg("tag"))
      .def("build_semaphore", &Communicator::buildSemaphore, nb::arg("connection"), nb::arg("remote_rank"),
           nb::arg("tag") = 0)
      .def("remote_rank_of", &Communicator::remoteRankOf)
      .def("tag_of", &Communicator::tagOf)
      // Intentionally a no-op on the C++ side.
      .def("setup", [](Communicator*) {});
}
|
|
|
|
NB_MODULE(_mscclpp, m) {
#ifdef MSCCLPP_DISABLE_NB_LEAK_WARNINGS
  nb::set_leak_warnings(false);
#endif

  // Every binding group is attached by one registrar. They are executed strictly in the order
  // listed here, which matches the module's established registration sequence.
  using Registrar = void (*)(nb::module_&);
  const Registrar registrars[] = {
      register_env,
      register_error,
      register_port_channel,
      register_memory_channel,
      register_fifo,
      register_semaphore,
      register_utils,
      register_core,
      register_numa,
      register_nvls,
      register_executor,
      register_npkit,
      register_gpu_utils,
      register_algorithm,

      // ext
      register_algorithm_collection_builder,
  };

  for (const Registrar registrar : registrars) {
    registrar(m);
  }
}