// -----------------------------------------------------------------------------
// NOTE(review): this file appears to have gone through a lossy text extraction.
//   * The original line breaks are gone; the code sits on a handful of very long
//     physical lines, and the first one starts with `//`, so everything on it is
//     lexically a comment.
//   * Every angle-bracket span is missing: the `#include` directives name no
//     header, `template` has no parameter list, and `nb::class_`, `nb::enum_`,
//     `nb::init`, `nb::init_implicit`, `nb::call_guard`, `reinterpret_cast`,
//     `static_cast`, `std::make_shared` and `def_shared_future` have no template
//     arguments.
//   The code tokens below are left exactly as found. Restore the file from
//   version control instead of hand-reconstructing the missing arguments
//   (TODO confirm against upstream before building).
//
// What the visible tokens show:
//   * A list of `extern void register_*(nb::module_& m)` declarations for
//     registration functions defined elsewhere.
//   * def_shared_future(m, typestr): registers a class whose Python name is
//     "CppSharedFuture_" + typestr and exposes one method, "get", bound to
//     `std::shared_future::get`.
//   * register_core(m): defines the module function "version" and registers
//     CppDataType, CppBootstrap, CppUniqueId, CppTcpBootstrap, CppTransport,
//     CppTransportFlags, CppDeviceType, CppDevice, CppIbMode,
//     CppEndpointConfigIb, CppRegisteredMemory, CppEndpoint, CppConnection,
//     CppEndpointConfig, CppContext, CppSemaphoreStub, CppSemaphore, three
//     shared-future wrappers, and CppCommunicator.
//   * Wherever a binding takes a raw address from Python (Bootstrap send/recv,
//     Connection update_and_sync, Context/Communicator register_memory) the
//     argument is declared as `uintptr_t` and cast to a pointer inside a lambda
//     before the underlying member function is called.
//   * NB_MODULE(_mscclpp, m): when MSCCLPP_DISABLE_NB_LEAK_WARNINGS is defined,
//     calls nb::set_leak_warnings(false); then calls the register_* functions
//     in the order written.
//
// This first physical line holds: license header, includes, the extern
// declarations, def_shared_future, and the start of register_core (CppDataType
// and CppBootstrap up to the body of the pointer-based "recv" lambda).
// -----------------------------------------------------------------------------
// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. #include #include #include #include #include #include #include namespace nb = nanobind; using namespace mscclpp; extern void register_env(nb::module_& m); extern void register_error(nb::module_& m); extern void register_port_channel(nb::module_& m); extern void register_memory_channel(nb::module_& m); extern void register_fifo(nb::module_& m); extern void register_semaphore(nb::module_& m); extern void register_utils(nb::module_& m); extern void register_numa(nb::module_& m); extern void register_nvls(nb::module_& m); extern void register_executor(nb::module_& m); extern void register_npkit(nb::module_& m); extern void register_gpu_utils(nb::module_& m); extern void register_algorithm(nb::module_& m); // ext extern void register_algorithm_collection_builder(nb::module_& m); template void def_shared_future(nb::handle& m, const std::string& typestr) { std::string pyclass_name = std::string("CppSharedFuture_") + typestr; nb::class_>(m, pyclass_name.c_str()).def("get", &std::shared_future::get); } void register_core(nb::module_& m) { m.def("version", &version); nb::enum_(m, "CppDataType") .value("int32", DataType::INT32) .value("uint32", DataType::UINT32) .value("float16", DataType::FLOAT16) .value("float32", DataType::FLOAT32) .value("bfloat16", DataType::BFLOAT16) .value("float8_e4m3", DataType::FLOAT8_E4M3) .value("float8_e5m2", DataType::FLOAT8_E5M2) .value("uint8", DataType::UINT8); nb::class_(m, "CppBootstrap") .def("get_rank", &Bootstrap::getRank) .def("get_n_ranks", &Bootstrap::getNranks) .def("get_n_ranks_per_node", &Bootstrap::getNranksPerNode) .def( "send", [](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) { void* data = reinterpret_cast(ptr); self->send(data, size, peer, tag); }, nb::arg("data"), nb::arg("size"), nb::arg("peer"), nb::arg("tag")) .def( "recv", [](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) { void* data = reinterpret_cast(ptr); 
// Continues inside the CppBootstrap "recv" lambda, then finishes CppBootstrap
// with "all_gather", "barrier" and a second pair of "send"/"recv" overloads
// (3 arguments: data, peer, tag) selected through static_cast.
// CppUniqueId: "__setstate__" copies UniqueIdBytes bytes from a Python bytes
// object into the id and throws std::runtime_error("Invalid UniqueId byte size")
// when the length differs; "__getstate__" returns the same bytes.
// CppTcpBootstrap: the constructor docstring tells users to call the static
// "create" instead, which returns the result of std::make_shared(rank, nRanks).
// Both "initialize" overloads default "timeout_sec" to 30.
// Then CppTransport and the first part of CppTransportFlags.
self->recv(data, size, peer, tag); }, nb::arg("data"), nb::arg("size"), nb::arg("peer"), nb::arg("tag")) .def("all_gather", &Bootstrap::allGather, nb::arg("allData"), nb::arg("size")) .def("barrier", &Bootstrap::barrier) .def("send", static_cast&, int, int)>(&Bootstrap::send), nb::arg("data"), nb::arg("peer"), nb::arg("tag")) .def("recv", static_cast&, int, int)>(&Bootstrap::recv), nb::arg("data"), nb::arg("peer"), nb::arg("tag")); nb::class_(m, "CppUniqueId") .def(nb::init<>()) .def("__setstate__", [](UniqueId& self, nb::bytes b) { if (nb::len(b) != UniqueIdBytes) throw std::runtime_error("Invalid UniqueId byte size"); ::memcpy(self.data(), b.c_str(), UniqueIdBytes); }) .def("__getstate__", [](const UniqueId& self) { return nb::bytes(reinterpret_cast(self.data()), UniqueIdBytes); }); nb::class_(m, "CppTcpBootstrap") .def(nb::init(), "Do not use this constructor. Use create instead.") .def_static( "create", [](int rank, int nRanks) { return std::make_shared(rank, nRanks); }, nb::arg("rank"), nb::arg("nRanks")) .def_static("create_unique_id", &TcpBootstrap::createUniqueId) .def("get_unique_id", &TcpBootstrap::getUniqueId) .def("initialize", static_cast(&TcpBootstrap::initialize), nb::call_guard(), nb::arg("unique_id"), nb::arg("timeout_sec") = 30) .def("initialize", static_cast(&TcpBootstrap::initialize), nb::call_guard(), nb::arg("if_ip_port_trio"), nb::arg("timeout_sec") = 30); nb::enum_(m, "CppTransport") .value("Unknown", Transport::Unknown) .value("CudaIpc", Transport::CudaIpc) .value("IB0", Transport::IB0) .value("IB1", Transport::IB1) .value("IB2", Transport::IB2) .value("IB3", Transport::IB3) .value("IB4", Transport::IB4) .value("IB5", Transport::IB5) .value("IB6", Transport::IB6) .value("IB7", Transport::IB7) .value("NumTransports", Transport::NumTransports); nb::class_(m, "CppTransportFlags") .def(nb::init<>()) .def(nb::init_implicit(), nb::arg("transport")) .def("has", &TransportFlags::has, nb::arg("transport")) .def("none", &TransportFlags::none) 
// Rest of CppTransportFlags: "any"/"all"/"count", the binary |, &, ^ operators
// against both another flags object and a Transport, in-place "__ior__",
// "__iand__", "__ixor__" written as lambdas that return the result of the
// compound assignment, plus ~, == and !=.
// Then CppDeviceType, CppDevice ("__str__" streams the object into a
// std::stringstream), CppIbMode, and the CppEndpointConfigIb constructor whose
// keyword defaults are -1 for "device_index" and the EndpointConfig::Ib::Default*
// constants / Mode::Default for the rest.
.def("any", &TransportFlags::any) .def("all", &TransportFlags::all) .def("count", &TransportFlags::count) .def(nb::self | nb::self) .def(nb::self | Transport()) .def(nb::self & nb::self) .def(nb::self & Transport()) .def(nb::self ^ nb::self) .def(nb::self ^ Transport()) .def( "__ior__", [](TransportFlags& lhs, const TransportFlags& rhs) { return lhs |= rhs; }, nb::is_operator()) .def( "__iand__", [](TransportFlags& lhs, const TransportFlags& rhs) { return lhs &= rhs; }, nb::is_operator()) .def( "__ixor__", [](TransportFlags& lhs, const TransportFlags& rhs) { return lhs ^= rhs; }, nb::is_operator()) .def(~nb::self) .def(nb::self == nb::self) .def(nb::self != nb::self); nb::enum_(m, "CppDeviceType") .value("Unknown", DeviceType::Unknown) .value("CPU", DeviceType::CPU) .value("GPU", DeviceType::GPU); nb::class_(m, "CppDevice") .def(nb::init<>()) .def(nb::init_implicit(), nb::arg("type")) .def(nb::init(), nb::arg("type"), nb::arg("id") = -1) .def_rw("type", &Device::type) .def_rw("id", &Device::id) .def("__str__", [](const Device& self) { std::stringstream ss; ss << self; return ss.str(); }); nb::enum_(m, "CppIbMode") .value("Default", EndpointConfig::Ib::Mode::Default) .value("Host", EndpointConfig::Ib::Mode::Host) .value("HostNoAtomic", EndpointConfig::Ib::Mode::HostNoAtomic); nb::class_(m, "CppEndpointConfigIb") .def(nb::init<>()) .def(nb::init(), nb::arg("device_index") = -1, nb::arg("port") = EndpointConfig::Ib::DefaultPort, nb::arg("gid_index") = EndpointConfig::Ib::DefaultGidIndex, nb::arg("max_cq_size") = EndpointConfig::Ib::DefaultMaxCqSize, nb::arg("max_cq_poll_num") = EndpointConfig::Ib::DefaultMaxCqPollNum, nb::arg("max_send_wr") = EndpointConfig::Ib::DefaultMaxSendWr, nb::arg("max_recv_wr") = EndpointConfig::Ib::DefaultMaxRecvWr, nb::arg("max_wr_per_send") = EndpointConfig::Ib::DefaultMaxWrPerSend, nb::arg("mode") = EndpointConfig::Ib::Mode::Default) .def_rw("device_index", &EndpointConfig::Ib::deviceIndex) .def_rw("port", &EndpointConfig::Ib::port) 
// Remaining read/write fields of CppEndpointConfigIb.
// CppRegisteredMemory: "data" returns the result of a reinterpret_cast applied
// to self.data() (target type stripped from this copy).
// CppEndpoint and CppConnection follow. In CppConnection, "update_and_sync"
// takes "src" as uintptr_t and passes it on as uint64_t*; "flush" defaults
// "timeout_usec" to (int64_t)3e7.
// Then the CppEndpointConfig constructors.
.def_rw("gid_index", &EndpointConfig::Ib::gidIndex) .def_rw("max_cq_size", &EndpointConfig::Ib::maxCqSize) .def_rw("max_cq_poll_num", &EndpointConfig::Ib::maxCqPollNum) .def_rw("max_send_wr", &EndpointConfig::Ib::maxSendWr) .def_rw("max_recv_wr", &EndpointConfig::Ib::maxRecvWr) .def_rw("max_wr_per_send", &EndpointConfig::Ib::maxWrPerSend) .def_rw("mode", &EndpointConfig::Ib::mode); nb::class_(m, "CppRegisteredMemory") .def(nb::init<>()) .def("data", [](RegisteredMemory& self) { return reinterpret_cast(self.data()); }) .def("size", &RegisteredMemory::size) .def("transports", &RegisteredMemory::transports) .def("serialize", &RegisteredMemory::serialize) .def_static("deserialize", &RegisteredMemory::deserialize, nb::arg("data")); nb::class_(m, "CppEndpoint") .def("config", &Endpoint::config) .def("transport", &Endpoint::transport) .def("device", &Endpoint::device) .def("max_write_queue_size", &Endpoint::maxWriteQueueSize) .def("serialize", &Endpoint::serialize) .def_static("deserialize", &Endpoint::deserialize, nb::arg("data")); nb::class_(m, "CppConnection") .def("write", &Connection::write, nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("srcOffset"), nb::arg("size")) .def( "update_and_sync", [](Connection* self, RegisteredMemory dst, uint64_t dstOffset, uintptr_t src, uint64_t newValue) { self->updateAndSync(dst, dstOffset, (uint64_t*)src, newValue); }, nb::arg("dst"), nb::arg("dst_offset"), nb::arg("src"), nb::arg("new_value")) .def("flush", &Connection::flush, nb::call_guard(), nb::arg("timeout_usec") = (int64_t)3e7) .def("transport", &Connection::transport) .def("remote_transport", &Connection::remoteTransport) .def("context", &Connection::context) .def("local_device", &Connection::localDevice) .def("get_max_write_queue_size", &Connection::getMaxWriteQueueSize); nb::class_(m, "CppEndpointConfig") .def(nb::init<>()) .def(nb::init_implicit(), nb::arg("transport")) .def(nb::init(), nb::arg("transport"), nb::arg("device"), 
// Rest of CppEndpointConfig: "max_write_queue_size" defaults to -1 and "ib" to
// a default-constructed EndpointConfig::Ib in the multi-argument constructor.
// Besides the "ib" field itself, each member of the nested `ib` struct is also
// exposed as a flat "ib_*" property whose getter/setter reads/writes self.ib.*.
// CppContext: "register_memory" takes the address as uintptr_t and casts it to
// void* before calling registerMemory.
nb::arg("max_write_queue_size") = -1, nb::arg("ib") = EndpointConfig::Ib{}) .def_rw("transport", &EndpointConfig::transport) .def_rw("device", &EndpointConfig::device) .def_rw("ib", &EndpointConfig::ib) .def_prop_rw( "ib_device_index", [](EndpointConfig& self) { return self.ib.deviceIndex; }, [](EndpointConfig& self, int v) { self.ib.deviceIndex = v; }) .def_prop_rw( "ib_port", [](EndpointConfig& self) { return self.ib.port; }, [](EndpointConfig& self, int v) { self.ib.port = v; }) .def_prop_rw( "ib_gid_index", [](EndpointConfig& self) { return self.ib.gidIndex; }, [](EndpointConfig& self, int v) { self.ib.gidIndex = v; }) .def_prop_rw( "ib_max_cq_size", [](EndpointConfig& self) { return self.ib.maxCqSize; }, [](EndpointConfig& self, int v) { self.ib.maxCqSize = v; }) .def_prop_rw( "ib_max_cq_poll_num", [](EndpointConfig& self) { return self.ib.maxCqPollNum; }, [](EndpointConfig& self, int v) { self.ib.maxCqPollNum = v; }) .def_prop_rw( "ib_max_send_wr", [](EndpointConfig& self) { return self.ib.maxSendWr; }, [](EndpointConfig& self, int v) { self.ib.maxSendWr = v; }) .def_prop_rw( "ib_max_recv_wr", [](EndpointConfig& self) { return self.ib.maxRecvWr; }, [](EndpointConfig& self, int v) { self.ib.maxRecvWr = v; }) .def_prop_rw( "ib_max_wr_per_send", [](EndpointConfig& self) { return self.ib.maxWrPerSend; }, [](EndpointConfig& self, int v) { self.ib.maxWrPerSend = v; }) .def_prop_rw( "ib_mode", [](EndpointConfig& self) { return self.ib.mode; }, [](EndpointConfig& self, EndpointConfig::Ib::Mode v) { self.ib.mode = v; }) .def_rw("max_write_queue_size", &EndpointConfig::maxWriteQueueSize); nb::class_(m, "CppContext") .def_static("create", &Context::create) .def( "register_memory", [](Context* self, uintptr_t ptr, size_t size, TransportFlags transports) { return self->registerMemory((void*)ptr, size, transports); }, nb::arg("ptr"), nb::arg("size"), nb::arg("transports")) .def("create_endpoint", &Context::createEndpoint, nb::arg("config")) .def("connect", 
// Finishes CppContext "connect", then CppSemaphoreStub and CppSemaphore, then
// three def_shared_future registrations with the suffixes "RegisteredMemory",
// "Connection" and "Semaphore".
// CppCommunicator: the "context" constructor argument defaults to nullptr.
// "connect_on_setup" accepts (remote_rank, tag, local_config) and forwards to
// connect(localConfig, remoteRank, tag), i.e. a different argument order from
// "connect". "send_memory_on_setup"/"recv_memory_on_setup" bind the same member
// functions as "send_memory"/"recv_memory" but give "tag" no default.
&Context::connect, nb::arg("local_endpoint"), nb::arg("remote_endpoint")); nb::class_(m, "CppSemaphoreStub") .def(nb::init(), nb::arg("connection")) .def("memory", &SemaphoreStub::memory) .def("serialize", &SemaphoreStub::serialize) .def_static("deserialize", &SemaphoreStub::deserialize, nb::arg("data")); nb::class_(m, "CppSemaphore") .def(nb::init<>()) .def(nb::init(), nb::arg("local_stub"), nb::arg("remote_stub")) .def("connection", &Semaphore::connection) .def("local_memory", &Semaphore::localMemory) .def("remote_memory", &Semaphore::remoteMemory); def_shared_future(m, "RegisteredMemory"); def_shared_future(m, "Connection"); def_shared_future(m, "Semaphore"); nb::class_(m, "CppCommunicator") .def(nb::init, std::shared_ptr>(), nb::arg("bootstrap"), nb::arg("context") = nullptr) .def("bootstrap", &Communicator::bootstrap) .def("context", &Communicator::context) .def( "register_memory", [](Communicator* self, uintptr_t ptr, size_t size, TransportFlags transports) { return self->registerMemory((void*)ptr, size, transports); }, nb::arg("ptr"), nb::arg("size"), nb::arg("transports")) .def("send_memory", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remote_rank"), nb::arg("tag") = 0) .def("recv_memory", &Communicator::recvMemory, nb::arg("remote_rank"), nb::arg("tag") = 0) .def("connect", static_cast (Communicator::*)(const EndpointConfig&, int, int)>( &Communicator::connect), nb::arg("local_config"), nb::arg("remote_rank"), nb::arg("tag") = 0) .def( "connect_on_setup", [](Communicator* self, int remoteRank, int tag, EndpointConfig localConfig) { return self->connect(std::move(localConfig), remoteRank, tag); }, nb::arg("remote_rank"), nb::arg("tag"), nb::arg("local_config")) .def("send_memory_on_setup", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remote_rank"), nb::arg("tag")) .def("recv_memory_on_setup", &Communicator::recvMemory, nb::arg("remote_rank"), nb::arg("tag")) .def("build_semaphore", &Communicator::buildSemaphore, nb::arg("connection"), 
// Tail of CppCommunicator: "build_semaphore" defaults "tag" to 0, and "setup"
// is bound to a lambda with an empty body, so calling it does nothing here.
// NB_MODULE(_mscclpp, m) is the module entry point: leak warnings are turned
// off only when MSCCLPP_DISABLE_NB_LEAK_WARNINGS is defined, after which the
// register_* functions run in the order listed.
// NOTE(review): because line breaks were lost, the `#ifdef`/`#endif` and the
// `// ext` comment now share a physical line with code — another reason to
// restore the original layout before compiling.
nb::arg("remote_rank"), nb::arg("tag") = 0) .def("remote_rank_of", &Communicator::remoteRankOf) .def("tag_of", &Communicator::tagOf) .def("setup", [](Communicator*) {}); } NB_MODULE(_mscclpp, m) { #ifdef MSCCLPP_DISABLE_NB_LEAK_WARNINGS nb::set_leak_warnings(false); #endif register_env(m); register_error(m); register_port_channel(m); register_memory_channel(m); register_fifo(m); register_semaphore(m); register_utils(m); register_core(m); register_numa(m); register_nvls(m); register_executor(m); register_npkit(m); register_gpu_utils(m); register_algorithm(m); // ext register_algorithm_collection_builder(m); }