// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. #include #include #include #include #include #include namespace nb = nanobind; using namespace mscclpp; extern void register_error(nb::module_& m); extern void register_proxy_channel(nb::module_& m); extern void register_sm_channel(nb::module_& m); extern void register_fifo(nb::module_& m); extern void register_semaphore(nb::module_& m); extern void register_config(nb::module_& m); extern void register_utils(nb::module_& m); template void def_nonblocking_future(nb::handle& m, const std::string& typestr) { std::string pyclass_name = std::string("NonblockingFuture") + typestr; nb::class_>(m, pyclass_name.c_str()) .def("ready", &NonblockingFuture::ready) .def("get", &NonblockingFuture::get); } void register_core(nb::module_& m) { nb::class_(m, "Bootstrap") .def("get_rank", &Bootstrap::getRank) .def("get_n_ranks", &Bootstrap::getNranks) .def( "send", [](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) { void* data = reinterpret_cast(ptr); self->send(data, size, peer, tag); }, nb::arg("data"), nb::arg("size"), nb::arg("peer"), nb::arg("tag")) .def( "recv", [](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) { void* data = reinterpret_cast(ptr); self->recv(data, size, peer, tag); }, nb::arg("data"), nb::arg("size"), nb::arg("peer"), nb::arg("tag")) .def("all_gather", &Bootstrap::allGather, nb::arg("allData"), nb::arg("size")) .def("barrier", &Bootstrap::barrier) .def("send", (void (Bootstrap::*)(const std::vector&, int, int)) & Bootstrap::send, nb::arg("data"), nb::arg("peer"), nb::arg("tag")) .def("recv", (void (Bootstrap::*)(std::vector&, int, int)) & Bootstrap::recv, nb::arg("data"), nb::arg("peer"), nb::arg("tag")); nb::class_(m, "UniqueId"); nb::class_(m, "TcpBootstrap") .def(nb::init(), "Do not use this constructor. Use create instead.") .def_static( "create", [](int rank, int nRanks) { return std::make_shared(rank, nRanks); }, nb::arg("rank"), nb::arg("nRanks")) .def("create_unique_id", &TcpBootstrap::createUniqueId) .def("get_unique_id", &TcpBootstrap::getUniqueId) .def("initialize", (void (TcpBootstrap::*)(UniqueId)) & TcpBootstrap::initialize, nb::arg("uniqueId")) .def("initialize", (void (TcpBootstrap::*)(const std::string&)) & TcpBootstrap::initialize, nb::arg("ifIpPortTrio")); nb::enum_(m, "Transport") .value("Unknown", Transport::Unknown) .value("CudaIpc", Transport::CudaIpc) .value("IB0", Transport::IB0) .value("IB1", Transport::IB1) .value("IB2", Transport::IB2) .value("IB3", Transport::IB3) .value("IB4", Transport::IB4) .value("IB5", Transport::IB5) .value("IB6", Transport::IB6) .value("IB7", Transport::IB7) .value("NumTransports", Transport::NumTransports); nb::class_(m, "TransportFlags") .def(nb::init<>()) .def(nb::init_implicit(), nb::arg("transport")) .def("has", &TransportFlags::has, nb::arg("transport")) .def("none", &TransportFlags::none) .def("any", &TransportFlags::any) .def("all", &TransportFlags::all) .def("count", &TransportFlags::count) .def(nb::self |= nb::self) .def(nb::self | nb::self) .def(nb::self | Transport()) .def(nb::self &= nb::self) .def(nb::self & nb::self) .def(nb::self & Transport()) .def(nb::self ^= nb::self) .def(nb::self ^ nb::self) .def(nb::self ^ Transport()) .def(~nb::self) .def(nb::self == nb::self) .def(nb::self != nb::self); nb::class_(m, "RegisteredMemory") .def(nb::init<>()) .def("data", &RegisteredMemory::data) .def("size", &RegisteredMemory::size) .def("rank", &RegisteredMemory::rank) .def("transports", &RegisteredMemory::transports) .def("serialize", &RegisteredMemory::serialize) .def_static("deserialize", &RegisteredMemory::deserialize, nb::arg("data")); nb::class_(m, "Connection") .def("write", &Connection::write, nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("srcOffset"), nb::arg("size")) .def( "update_and_sync", [](Connection* self, RegisteredMemory dst, uint64_t dstOffset, uintptr_t src, uint64_t newValue) { self->updateAndSync(dst, dstOffset, (uint64_t*)src, newValue); }, nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("newValue")) .def("flush", &Connection::flush) .def("remote_rank", &Connection::remoteRank) .def("tag", &Connection::tag) .def("transport", &Connection::transport) .def("remote_transport", &Connection::remoteTransport); def_nonblocking_future(m, "RegisteredMemory"); nb::class_(m, "Communicator") .def(nb::init>(), nb::arg("bootstrap")) .def("bootstrap", &Communicator::bootstrap) .def( "register_memory", [](Communicator* self, uintptr_t ptr, size_t size, TransportFlags transports) { return self->registerMemory((void*)ptr, size, transports); }, nb::arg("ptr"), nb::arg("size"), nb::arg("transports")) .def("send_memory_on_setup", &Communicator::sendMemoryOnSetup, nb::arg("memory"), nb::arg("remoteRank"), nb::arg("tag")) .def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag")) .def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"), nb::arg("transport")) .def("setup", &Communicator::setup); } NB_MODULE(mscclpp, m) { register_error(m); register_proxy_channel(m); register_sm_channel(m); register_fifo(m); register_semaphore(m); register_config(m); register_utils(m); register_core(m); }