Address comments for PR #692 (#733)

Rename nanobind-exposed C++ types to Cpp*
Replace MSCCLPP_EXECUTION_PLAN_DIR / MSCCLPP_NATIVE_CACHE_DIR with
MSCCLPP_CACHE_DIR across C++ and Python.
This commit is contained in:
Binyang Li
2026-02-03 10:13:20 -08:00
committed by GitHub
parent 03b1936ddb
commit e21513791a
31 changed files with 211 additions and 205 deletions

View File

@@ -32,21 +32,21 @@ extern void register_algorithm_collection_builder(nb::module_& m);
template <typename T>
void def_shared_future(nb::handle& m, const std::string& typestr) {
std::string pyclass_name = std::string("shared_future_") + typestr;
std::string pyclass_name = std::string("CppSharedFuture_") + typestr;
nb::class_<std::shared_future<T>>(m, pyclass_name.c_str()).def("get", &std::shared_future<T>::get);
}
void register_core(nb::module_& m) {
m.def("version", &version);
nb::enum_<DataType>(m, "DataType")
nb::enum_<DataType>(m, "CppDataType")
.value("int32", DataType::INT32)
.value("uint32", DataType::UINT32)
.value("float16", DataType::FLOAT16)
.value("float32", DataType::FLOAT32)
.value("bfloat16", DataType::BFLOAT16);
nb::class_<Bootstrap>(m, "Bootstrap")
nb::class_<Bootstrap>(m, "CppBootstrap")
.def("get_rank", &Bootstrap::getRank)
.def("get_n_ranks", &Bootstrap::getNranks)
.def("get_n_ranks_per_node", &Bootstrap::getNranksPerNode)
@@ -71,7 +71,7 @@ void register_core(nb::module_& m) {
.def("recv", static_cast<void (Bootstrap::*)(std::vector<char>&, int, int)>(&Bootstrap::recv), nb::arg("data"),
nb::arg("peer"), nb::arg("tag"));
nb::class_<UniqueId>(m, "UniqueId")
nb::class_<UniqueId>(m, "CppUniqueId")
.def(nb::init<>())
.def("__setstate__",
[](UniqueId& self, nb::bytes b) {
@@ -81,7 +81,7 @@ void register_core(nb::module_& m) {
.def("__getstate__",
[](const UniqueId& self) { return nb::bytes(reinterpret_cast<const char*>(self.data()), UniqueIdBytes); });
nb::class_<TcpBootstrap, Bootstrap>(m, "TcpBootstrap")
nb::class_<TcpBootstrap, Bootstrap>(m, "CppTcpBootstrap")
.def(nb::init<int, int>(), "Do not use this constructor. Use create instead.")
.def_static(
"create", [](int rank, int nRanks) { return std::make_shared<TcpBootstrap>(rank, nRanks); }, nb::arg("rank"),
@@ -93,7 +93,7 @@ void register_core(nb::module_& m) {
.def("initialize", static_cast<void (TcpBootstrap::*)(const std::string&, int64_t)>(&TcpBootstrap::initialize),
nb::call_guard<nb::gil_scoped_release>(), nb::arg("if_ip_port_trio"), nb::arg("timeout_sec") = 30);
nb::enum_<Transport>(m, "Transport")
nb::enum_<Transport>(m, "CppTransport")
.value("Unknown", Transport::Unknown)
.value("CudaIpc", Transport::CudaIpc)
.value("IB0", Transport::IB0)
@@ -106,7 +106,7 @@ void register_core(nb::module_& m) {
.value("IB7", Transport::IB7)
.value("NumTransports", Transport::NumTransports);
nb::class_<TransportFlags>(m, "TransportFlags")
nb::class_<TransportFlags>(m, "CppTransportFlags")
.def(nb::init<>())
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
.def("has", &TransportFlags::has, nb::arg("transport"))
@@ -130,12 +130,12 @@ void register_core(nb::module_& m) {
.def(nb::self == nb::self)
.def(nb::self != nb::self);
nb::enum_<DeviceType>(m, "DeviceType")
nb::enum_<DeviceType>(m, "CppDeviceType")
.value("Unknown", DeviceType::Unknown)
.value("CPU", DeviceType::CPU)
.value("GPU", DeviceType::GPU);
nb::class_<Device>(m, "Device")
nb::class_<Device>(m, "CppDevice")
.def(nb::init<>())
.def(nb::init_implicit<DeviceType>(), nb::arg("type"))
.def(nb::init<DeviceType, int>(), nb::arg("type"), nb::arg("id") = -1)
@@ -147,7 +147,7 @@ void register_core(nb::module_& m) {
return ss.str();
});
nb::class_<EndpointConfig::Ib>(m, "EndpointConfigIb")
nb::class_<EndpointConfig::Ib>(m, "CppEndpointConfigIb")
.def(nb::init<>())
.def(nb::init<int, int, int, int, int, int, int>(), nb::arg("device_index") = -1,
nb::arg("port") = EndpointConfig::Ib::DefaultPort,
@@ -164,7 +164,7 @@ void register_core(nb::module_& m) {
.def_rw("max_send_wr", &EndpointConfig::Ib::maxSendWr)
.def_rw("max_wr_per_send", &EndpointConfig::Ib::maxWrPerSend);
nb::class_<RegisteredMemory>(m, "RegisteredMemory")
nb::class_<RegisteredMemory>(m, "CppRegisteredMemory")
.def(nb::init<>())
.def("data", [](RegisteredMemory& self) { return reinterpret_cast<uintptr_t>(self.data()); })
.def("size", &RegisteredMemory::size)
@@ -172,7 +172,7 @@ void register_core(nb::module_& m) {
.def("serialize", &RegisteredMemory::serialize)
.def_static("deserialize", &RegisteredMemory::deserialize, nb::arg("data"));
nb::class_<Endpoint>(m, "Endpoint")
nb::class_<Endpoint>(m, "CppEndpoint")
.def("config", &Endpoint::config)
.def("transport", &Endpoint::transport)
.def("device", &Endpoint::device)
@@ -180,7 +180,7 @@ void register_core(nb::module_& m) {
.def("serialize", &Endpoint::serialize)
.def_static("deserialize", &Endpoint::deserialize, nb::arg("data"));
nb::class_<Connection>(m, "Connection")
nb::class_<Connection>(m, "CppConnection")
.def("write", &Connection::write, nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("srcOffset"),
nb::arg("size"))
.def(
@@ -197,7 +197,7 @@ void register_core(nb::module_& m) {
.def("local_device", &Connection::localDevice)
.def("get_max_write_queue_size", &Connection::getMaxWriteQueueSize);
nb::class_<EndpointConfig>(m, "EndpointConfig")
nb::class_<EndpointConfig>(m, "CppEndpointConfig")
.def(nb::init<>())
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
.def(nb::init<Transport, Device, int, EndpointConfig::Ib>(), nb::arg("transport"), nb::arg("device"),
@@ -228,7 +228,7 @@ void register_core(nb::module_& m) {
[](EndpointConfig& self, int v) { self.ib.maxWrPerSend = v; })
.def_rw("max_write_queue_size", &EndpointConfig::maxWriteQueueSize);
nb::class_<Context>(m, "Context")
nb::class_<Context>(m, "CppContext")
.def_static("create", &Context::create)
.def(
"register_memory",
@@ -239,13 +239,13 @@ void register_core(nb::module_& m) {
.def("create_endpoint", &Context::createEndpoint, nb::arg("config"))
.def("connect", &Context::connect, nb::arg("local_endpoint"), nb::arg("remote_endpoint"));
nb::class_<SemaphoreStub>(m, "SemaphoreStub")
nb::class_<SemaphoreStub>(m, "CppSemaphoreStub")
.def(nb::init<const Connection&>(), nb::arg("connection"))
.def("memory", &SemaphoreStub::memory)
.def("serialize", &SemaphoreStub::serialize)
.def_static("deserialize", &SemaphoreStub::deserialize, nb::arg("data"));
nb::class_<Semaphore>(m, "Semaphore")
nb::class_<Semaphore>(m, "CppSemaphore")
.def(nb::init<>())
.def(nb::init<const SemaphoreStub&, const SemaphoreStub&>(), nb::arg("local_stub"), nb::arg("remote_stub"))
.def("connection", &Semaphore::connection)
@@ -256,7 +256,7 @@ void register_core(nb::module_& m) {
def_shared_future<Connection>(m, "Connection");
def_shared_future<Semaphore>(m, "Semaphore");
nb::class_<Communicator>(m, "Communicator")
nb::class_<Communicator>(m, "CppCommunicator")
.def(nb::init<std::shared_ptr<Bootstrap>, std::shared_ptr<Context>>(), nb::arg("bootstrap"),
nb::arg("context") = nullptr)
.def("bootstrap", &Communicator::bootstrap)