mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-13 01:36:10 +00:00
Merge latest multinode branch
This commit is contained in:
@@ -24,4 +24,7 @@ set_target_properties(mscclpp_py PROPERTIES OUTPUT_NAME _mscclpp)
|
||||
set_target_properties(mscclpp_py PROPERTIES INSTALL_RPATH "\$ORIGIN/lib")
|
||||
target_link_libraries(mscclpp_py PRIVATE dlpack mscclpp mscclpp_collectives ${GPU_LIBRARIES})
|
||||
target_include_directories(mscclpp_py SYSTEM PRIVATE ${GPU_INCLUDE_DIRS})
|
||||
if(MSCCLPP_USE_ROCM)
|
||||
target_compile_definitions(mscclpp_py PRIVATE MSCCLPP_USE_ROCM)
|
||||
endif()
|
||||
install(TARGETS mscclpp_py LIBRARY DESTINATION .)
|
||||
|
||||
@@ -16,14 +16,16 @@ namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_algorithm(nb::module_& m) {
|
||||
nb::enum_<CollectiveBufferMode>(m, "CollectiveBufferMode")
|
||||
nb::enum_<CollectiveBufferMode>(m, "CppCollectiveBufferMode")
|
||||
.value("ANY", CollectiveBufferMode::Any)
|
||||
.value("IN_PLACE", CollectiveBufferMode::InPlace)
|
||||
.value("OUT_OF_PLACE", CollectiveBufferMode::OutOfPlace);
|
||||
|
||||
nb::enum_<AlgorithmType>(m, "AlgorithmType").value("NATIVE", AlgorithmType::Native).value("DSL", AlgorithmType::DSL);
|
||||
nb::enum_<AlgorithmType>(m, "CppAlgorithmType")
|
||||
.value("NATIVE", AlgorithmType::Native)
|
||||
.value("DSL", AlgorithmType::DSL);
|
||||
|
||||
nb::enum_<CommResult>(m, "CommResult")
|
||||
nb::enum_<CommResult>(m, "CppCommResult")
|
||||
.value("COMM_SUCCESS", CommResult::CommSuccess)
|
||||
.value("COMM_UNHANDLED_CUDA_ERROR", CommResult::CommUnhandledCudaError)
|
||||
.value("COMM_SYSTEM_ERROR", CommResult::CommSystemError)
|
||||
@@ -34,13 +36,13 @@ void register_algorithm(nb::module_& m) {
|
||||
.value("COMM_IN_PROGRESS", CommResult::CommInProgress)
|
||||
.value("COMM_NUM_RESULTS", CommResult::CommNumResults);
|
||||
|
||||
nb::enum_<ReduceOp>(m, "ReduceOp")
|
||||
nb::enum_<ReduceOp>(m, "CppReduceOp")
|
||||
.value("SUM", ReduceOp::SUM)
|
||||
.value("MIN", ReduceOp::MIN)
|
||||
.value("NOP", ReduceOp::NOP);
|
||||
|
||||
auto algorithmClass =
|
||||
nb::class_<Algorithm>(m, "Algorithm")
|
||||
nb::class_<Algorithm>(m, "CppAlgorithm")
|
||||
.def_static(
|
||||
"from_native_capsule",
|
||||
[](nb::capsule cap) {
|
||||
@@ -58,6 +60,12 @@ void register_algorithm(nb::module_& m) {
|
||||
.def_prop_ro("name", &Algorithm::name)
|
||||
.def_prop_ro("collective", &Algorithm::collective)
|
||||
.def_prop_ro("message_range", &Algorithm::messageRange)
|
||||
.def(
|
||||
"set_message_size_range",
|
||||
[](Algorithm& self, size_t minMessageSize, size_t maxMessageSize) {
|
||||
self.setMessageSizeRange(minMessageSize, maxMessageSize);
|
||||
},
|
||||
nb::arg("min_message_size"), nb::arg("max_message_size"))
|
||||
.def_prop_ro("tags", &Algorithm::tags)
|
||||
.def_prop_ro("buffer_mode", &Algorithm::bufferMode)
|
||||
.def_prop_ro("constraint", &Algorithm::constraint)
|
||||
@@ -67,16 +75,19 @@ void register_algorithm(nb::module_& m) {
|
||||
"execute",
|
||||
[](Algorithm& self, std::shared_ptr<Communicator> comm, uintptr_t input, uintptr_t output,
|
||||
size_t inputSize, size_t outputSize, DataType dtype, ReduceOp op, uintptr_t stream,
|
||||
std::shared_ptr<Executor> executor, int nBlocks, int nThreadsPerBlock,
|
||||
std::unordered_map<std::string, uintptr_t> extras) {
|
||||
std::shared_ptr<Executor> executor, int nBlocks, int nThreadsPerBlock, bool symmetricMemory,
|
||||
std::unordered_map<std::string, uintptr_t> extras, int32_t accumDtype) {
|
||||
return self.execute(comm, reinterpret_cast<const void*>(input), reinterpret_cast<void*>(output),
|
||||
inputSize, outputSize, dtype, op, reinterpret_cast<cudaStream_t>(stream), executor,
|
||||
nBlocks, nThreadsPerBlock, extras);
|
||||
nBlocks, nThreadsPerBlock, symmetricMemory, extras,
|
||||
static_cast<DataType>(accumDtype));
|
||||
},
|
||||
nb::arg("comm"), nb::arg("input"), nb::arg("output"), nb::arg("input_size"), nb::arg("output_size"),
|
||||
nb::arg("dtype"), nb::arg("op") = ReduceOp::NOP, nb::arg("stream") = 0, nb::arg("executor") = nullptr,
|
||||
nb::arg("n_blocks") = 0, nb::arg("n_threads_per_block") = 0,
|
||||
nb::arg("extras") = std::unordered_map<std::string, uintptr_t>());
|
||||
nb::arg("n_blocks") = 0, nb::arg("n_threads_per_block") = 0, nb::arg("symmetric_memory") = false,
|
||||
nb::arg("extras") = std::unordered_map<std::string, uintptr_t>(),
|
||||
nb::arg("accum_dtype") = static_cast<int32_t>(DataType::AUTO))
|
||||
.def("reset", &Algorithm::reset);
|
||||
|
||||
nb::class_<Algorithm::Constraint>(algorithmClass, "Constraint")
|
||||
.def(nb::init<>())
|
||||
@@ -84,21 +95,21 @@ void register_algorithm(nb::module_& m) {
|
||||
.def_rw("world_size", &Algorithm::Constraint::worldSize)
|
||||
.def_rw("n_ranks_per_node", &Algorithm::Constraint::nRanksPerNode);
|
||||
|
||||
nb::class_<AlgorithmBuilder>(m, "AlgorithmBuilder").def("build", &AlgorithmBuilder::build);
|
||||
nb::class_<AlgorithmBuilder>(m, "CppAlgorithmBuilder").def("build", &AlgorithmBuilder::build);
|
||||
|
||||
nb::class_<DslAlgorithm, Algorithm>(m, "DslAlgorithm")
|
||||
nb::class_<DslAlgorithm, Algorithm>(m, "CppDslAlgorithm")
|
||||
.def(nb::init<std::string, ExecutionPlan, std::unordered_map<std::string, uint64_t>, Algorithm::Constraint>(),
|
||||
nb::arg("id"), nb::arg("plan"), nb::arg("tags") = std::unordered_map<std::string, uint64_t>(),
|
||||
nb::arg("constraint") = Algorithm::Constraint())
|
||||
.def("build", &DslAlgorithm::build);
|
||||
|
||||
nb::class_<AlgorithmCollection>(m, "AlgorithmCollection")
|
||||
nb::class_<AlgorithmCollection>(m, "CppAlgorithmCollection")
|
||||
.def("register_algorithm", &AlgorithmCollection::registerAlgorithm, nb::arg("collective"), nb::arg("algo_name"),
|
||||
nb::arg("algorithm"))
|
||||
.def("get_algorithms_by_collective", &AlgorithmCollection::getAlgorithmsByCollective, nb::arg("collective"))
|
||||
.def("to_list", &AlgorithmCollection::getAllAlgorithms);
|
||||
|
||||
nb::class_<CollectiveRequest>(m, "CollectiveRequest")
|
||||
nb::class_<CollectiveRequest>(m, "CppCollectiveRequest")
|
||||
.def_ro("world_size", &CollectiveRequest::worldSize)
|
||||
.def_ro("n_ranks_per_node", &CollectiveRequest::nRanksPerNode)
|
||||
.def_ro("rank", &CollectiveRequest::rank)
|
||||
@@ -107,8 +118,22 @@ void register_algorithm(nb::module_& m) {
|
||||
.def_prop_ro("output_buffer",
|
||||
[](const CollectiveRequest& self) { return reinterpret_cast<uintptr_t>(self.outputBuffer); })
|
||||
.def_ro("message_size", &CollectiveRequest::messageSize)
|
||||
.def_prop_ro("stream", [](const CollectiveRequest& self) { return reinterpret_cast<uintptr_t>(self.stream); })
|
||||
.def_prop_ro("collective", [](const CollectiveRequest& self) { return self.collective; })
|
||||
.def_ro("dtype", &CollectiveRequest::dtype)
|
||||
.def_prop_ro("hints", [](const CollectiveRequest& self) { return self.hints; })
|
||||
.def("buffer_mode", &CollectiveRequest::bufferMode);
|
||||
|
||||
m.def(
|
||||
"cpp_get_flag_buffer",
|
||||
[]() {
|
||||
auto [buffer, size] = getFlagBuffer();
|
||||
uintptr_t ptr = reinterpret_cast<uintptr_t>(buffer.get());
|
||||
// Transfer shared_ptr ownership into a capsule so Python's GC manages the lifetime.
|
||||
auto prevent = std::make_unique<std::shared_ptr<void>>(std::move(buffer));
|
||||
nb::capsule owner(prevent.get(), [](void* p) noexcept { delete static_cast<std::shared_ptr<void>*>(p); });
|
||||
prevent.release(); // capsule now owns the pointer
|
||||
return nb::make_tuple(ptr, size, owner);
|
||||
},
|
||||
"Get the default flag buffer. Returns a tuple of (buffer_ptr, buffer_size, owner).");
|
||||
}
|
||||
@@ -32,21 +32,25 @@ extern void register_algorithm_collection_builder(nb::module_& m);
|
||||
|
||||
template <typename T>
|
||||
void def_shared_future(nb::handle& m, const std::string& typestr) {
|
||||
std::string pyclass_name = std::string("shared_future_") + typestr;
|
||||
std::string pyclass_name = std::string("CppSharedFuture_") + typestr;
|
||||
nb::class_<std::shared_future<T>>(m, pyclass_name.c_str()).def("get", &std::shared_future<T>::get);
|
||||
}
|
||||
|
||||
void register_core(nb::module_& m) {
|
||||
m.def("version", &version);
|
||||
|
||||
nb::enum_<DataType>(m, "DataType")
|
||||
nb::enum_<DataType>(m, "CppDataType")
|
||||
.value("int32", DataType::INT32)
|
||||
.value("uint32", DataType::UINT32)
|
||||
.value("float16", DataType::FLOAT16)
|
||||
.value("float32", DataType::FLOAT32)
|
||||
.value("bfloat16", DataType::BFLOAT16);
|
||||
.value("bfloat16", DataType::BFLOAT16)
|
||||
.value("float8_e4m3", DataType::FLOAT8_E4M3)
|
||||
.value("float8_e5m2", DataType::FLOAT8_E5M2)
|
||||
.value("uint8", DataType::UINT8)
|
||||
.value("float8_e4m3b15", DataType::FLOAT8_E4M3B15);
|
||||
|
||||
nb::class_<Bootstrap>(m, "Bootstrap")
|
||||
nb::class_<Bootstrap>(m, "CppBootstrap")
|
||||
.def("get_rank", &Bootstrap::getRank)
|
||||
.def("get_n_ranks", &Bootstrap::getNranks)
|
||||
.def("get_n_ranks_per_node", &Bootstrap::getNranksPerNode)
|
||||
@@ -71,7 +75,7 @@ void register_core(nb::module_& m) {
|
||||
.def("recv", static_cast<void (Bootstrap::*)(std::vector<char>&, int, int)>(&Bootstrap::recv), nb::arg("data"),
|
||||
nb::arg("peer"), nb::arg("tag"));
|
||||
|
||||
nb::class_<UniqueId>(m, "UniqueId")
|
||||
nb::class_<UniqueId>(m, "CppUniqueId")
|
||||
.def(nb::init<>())
|
||||
.def("__setstate__",
|
||||
[](UniqueId& self, nb::bytes b) {
|
||||
@@ -81,7 +85,7 @@ void register_core(nb::module_& m) {
|
||||
.def("__getstate__",
|
||||
[](const UniqueId& self) { return nb::bytes(reinterpret_cast<const char*>(self.data()), UniqueIdBytes); });
|
||||
|
||||
nb::class_<TcpBootstrap, Bootstrap>(m, "TcpBootstrap")
|
||||
nb::class_<TcpBootstrap, Bootstrap>(m, "CppTcpBootstrap")
|
||||
.def(nb::init<int, int>(), "Do not use this constructor. Use create instead.")
|
||||
.def_static(
|
||||
"create", [](int rank, int nRanks) { return std::make_shared<TcpBootstrap>(rank, nRanks); }, nb::arg("rank"),
|
||||
@@ -93,7 +97,7 @@ void register_core(nb::module_& m) {
|
||||
.def("initialize", static_cast<void (TcpBootstrap::*)(const std::string&, int64_t)>(&TcpBootstrap::initialize),
|
||||
nb::call_guard<nb::gil_scoped_release>(), nb::arg("if_ip_port_trio"), nb::arg("timeout_sec") = 30);
|
||||
|
||||
nb::enum_<Transport>(m, "Transport")
|
||||
nb::enum_<Transport>(m, "CppTransport")
|
||||
.value("Unknown", Transport::Unknown)
|
||||
.value("CudaIpc", Transport::CudaIpc)
|
||||
.value("IB0", Transport::IB0)
|
||||
@@ -106,7 +110,7 @@ void register_core(nb::module_& m) {
|
||||
.value("IB7", Transport::IB7)
|
||||
.value("NumTransports", Transport::NumTransports);
|
||||
|
||||
nb::class_<TransportFlags>(m, "TransportFlags")
|
||||
nb::class_<TransportFlags>(m, "CppTransportFlags")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
|
||||
.def("has", &TransportFlags::has, nb::arg("transport"))
|
||||
@@ -130,12 +134,12 @@ void register_core(nb::module_& m) {
|
||||
.def(nb::self == nb::self)
|
||||
.def(nb::self != nb::self);
|
||||
|
||||
nb::enum_<DeviceType>(m, "DeviceType")
|
||||
nb::enum_<DeviceType>(m, "CppDeviceType")
|
||||
.value("Unknown", DeviceType::Unknown)
|
||||
.value("CPU", DeviceType::CPU)
|
||||
.value("GPU", DeviceType::GPU);
|
||||
|
||||
nb::class_<Device>(m, "Device")
|
||||
nb::class_<Device>(m, "CppDevice")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init_implicit<DeviceType>(), nb::arg("type"))
|
||||
.def(nb::init<DeviceType, int>(), nb::arg("type"), nb::arg("id") = -1)
|
||||
@@ -147,24 +151,33 @@ void register_core(nb::module_& m) {
|
||||
return ss.str();
|
||||
});
|
||||
|
||||
nb::class_<EndpointConfig::Ib>(m, "EndpointConfigIb")
|
||||
nb::enum_<EndpointConfig::Ib::Mode>(m, "CppIbMode")
|
||||
.value("Default", EndpointConfig::Ib::Mode::Default)
|
||||
.value("Host", EndpointConfig::Ib::Mode::Host)
|
||||
.value("HostNoAtomic", EndpointConfig::Ib::Mode::HostNoAtomic);
|
||||
|
||||
nb::class_<EndpointConfig::Ib>(m, "CppEndpointConfigIb")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init<int, int, int, int, int, int, int>(), nb::arg("device_index") = -1,
|
||||
.def(nb::init<int, int, int, int, int, int, int, int, EndpointConfig::Ib::Mode>(), nb::arg("device_index") = -1,
|
||||
nb::arg("port") = EndpointConfig::Ib::DefaultPort,
|
||||
nb::arg("gid_index") = EndpointConfig::Ib::DefaultGidIndex,
|
||||
nb::arg("max_cq_size") = EndpointConfig::Ib::DefaultMaxCqSize,
|
||||
nb::arg("max_cq_poll_num") = EndpointConfig::Ib::DefaultMaxCqPollNum,
|
||||
nb::arg("max_send_wr") = EndpointConfig::Ib::DefaultMaxSendWr,
|
||||
nb::arg("max_wr_per_send") = EndpointConfig::Ib::DefaultMaxWrPerSend)
|
||||
nb::arg("max_recv_wr") = EndpointConfig::Ib::DefaultMaxRecvWr,
|
||||
nb::arg("max_wr_per_send") = EndpointConfig::Ib::DefaultMaxWrPerSend,
|
||||
nb::arg("mode") = EndpointConfig::Ib::Mode::Default)
|
||||
.def_rw("device_index", &EndpointConfig::Ib::deviceIndex)
|
||||
.def_rw("port", &EndpointConfig::Ib::port)
|
||||
.def_rw("gid_index", &EndpointConfig::Ib::gidIndex)
|
||||
.def_rw("max_cq_size", &EndpointConfig::Ib::maxCqSize)
|
||||
.def_rw("max_cq_poll_num", &EndpointConfig::Ib::maxCqPollNum)
|
||||
.def_rw("max_send_wr", &EndpointConfig::Ib::maxSendWr)
|
||||
.def_rw("max_wr_per_send", &EndpointConfig::Ib::maxWrPerSend);
|
||||
.def_rw("max_recv_wr", &EndpointConfig::Ib::maxRecvWr)
|
||||
.def_rw("max_wr_per_send", &EndpointConfig::Ib::maxWrPerSend)
|
||||
.def_rw("mode", &EndpointConfig::Ib::mode);
|
||||
|
||||
nb::class_<RegisteredMemory>(m, "RegisteredMemory")
|
||||
nb::class_<RegisteredMemory>(m, "CppRegisteredMemory")
|
||||
.def(nb::init<>())
|
||||
.def("data", [](RegisteredMemory& self) { return reinterpret_cast<uintptr_t>(self.data()); })
|
||||
.def("size", &RegisteredMemory::size)
|
||||
@@ -172,7 +185,7 @@ void register_core(nb::module_& m) {
|
||||
.def("serialize", &RegisteredMemory::serialize)
|
||||
.def_static("deserialize", &RegisteredMemory::deserialize, nb::arg("data"));
|
||||
|
||||
nb::class_<Endpoint>(m, "Endpoint")
|
||||
nb::class_<Endpoint>(m, "CppEndpoint")
|
||||
.def("config", &Endpoint::config)
|
||||
.def("transport", &Endpoint::transport)
|
||||
.def("device", &Endpoint::device)
|
||||
@@ -180,7 +193,7 @@ void register_core(nb::module_& m) {
|
||||
.def("serialize", &Endpoint::serialize)
|
||||
.def_static("deserialize", &Endpoint::deserialize, nb::arg("data"));
|
||||
|
||||
nb::class_<Connection>(m, "Connection")
|
||||
nb::class_<Connection>(m, "CppConnection")
|
||||
.def("write", &Connection::write, nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("srcOffset"),
|
||||
nb::arg("size"))
|
||||
.def(
|
||||
@@ -197,7 +210,7 @@ void register_core(nb::module_& m) {
|
||||
.def("local_device", &Connection::localDevice)
|
||||
.def("get_max_write_queue_size", &Connection::getMaxWriteQueueSize);
|
||||
|
||||
nb::class_<EndpointConfig>(m, "EndpointConfig")
|
||||
nb::class_<EndpointConfig>(m, "CppEndpointConfig")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
|
||||
.def(nb::init<Transport, Device, int, EndpointConfig::Ib>(), nb::arg("transport"), nb::arg("device"),
|
||||
@@ -223,12 +236,18 @@ void register_core(nb::module_& m) {
|
||||
.def_prop_rw(
|
||||
"ib_max_send_wr", [](EndpointConfig& self) { return self.ib.maxSendWr; },
|
||||
[](EndpointConfig& self, int v) { self.ib.maxSendWr = v; })
|
||||
.def_prop_rw(
|
||||
"ib_max_recv_wr", [](EndpointConfig& self) { return self.ib.maxRecvWr; },
|
||||
[](EndpointConfig& self, int v) { self.ib.maxRecvWr = v; })
|
||||
.def_prop_rw(
|
||||
"ib_max_wr_per_send", [](EndpointConfig& self) { return self.ib.maxWrPerSend; },
|
||||
[](EndpointConfig& self, int v) { self.ib.maxWrPerSend = v; })
|
||||
.def_prop_rw(
|
||||
"ib_mode", [](EndpointConfig& self) { return self.ib.mode; },
|
||||
[](EndpointConfig& self, EndpointConfig::Ib::Mode v) { self.ib.mode = v; })
|
||||
.def_rw("max_write_queue_size", &EndpointConfig::maxWriteQueueSize);
|
||||
|
||||
nb::class_<Context>(m, "Context")
|
||||
nb::class_<Context>(m, "CppContext")
|
||||
.def_static("create", &Context::create)
|
||||
.def(
|
||||
"register_memory",
|
||||
@@ -239,13 +258,13 @@ void register_core(nb::module_& m) {
|
||||
.def("create_endpoint", &Context::createEndpoint, nb::arg("config"))
|
||||
.def("connect", &Context::connect, nb::arg("local_endpoint"), nb::arg("remote_endpoint"));
|
||||
|
||||
nb::class_<SemaphoreStub>(m, "SemaphoreStub")
|
||||
nb::class_<SemaphoreStub>(m, "CppSemaphoreStub")
|
||||
.def(nb::init<const Connection&>(), nb::arg("connection"))
|
||||
.def("memory", &SemaphoreStub::memory)
|
||||
.def("serialize", &SemaphoreStub::serialize)
|
||||
.def_static("deserialize", &SemaphoreStub::deserialize, nb::arg("data"));
|
||||
|
||||
nb::class_<Semaphore>(m, "Semaphore")
|
||||
nb::class_<Semaphore>(m, "CppSemaphore")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init<const SemaphoreStub&, const SemaphoreStub&>(), nb::arg("local_stub"), nb::arg("remote_stub"))
|
||||
.def("connection", &Semaphore::connection)
|
||||
@@ -256,7 +275,7 @@ void register_core(nb::module_& m) {
|
||||
def_shared_future<Connection>(m, "Connection");
|
||||
def_shared_future<Semaphore>(m, "Semaphore");
|
||||
|
||||
nb::class_<Communicator>(m, "Communicator")
|
||||
nb::class_<Communicator>(m, "CppCommunicator")
|
||||
.def(nb::init<std::shared_ptr<Bootstrap>, std::shared_ptr<Context>>(), nb::arg("bootstrap"),
|
||||
nb::arg("context") = nullptr)
|
||||
.def("bootstrap", &Communicator::bootstrap)
|
||||
@@ -289,6 +308,9 @@ void register_core(nb::module_& m) {
|
||||
}
|
||||
|
||||
NB_MODULE(_mscclpp, m) {
|
||||
#ifdef MSCCLPP_DISABLE_NB_LEAK_WARNINGS
|
||||
nb::set_leak_warnings(false);
|
||||
#endif
|
||||
register_env(m);
|
||||
register_error(m);
|
||||
register_port_channel(m);
|
||||
@@ -306,4 +328,4 @@ NB_MODULE(_mscclpp, m) {
|
||||
|
||||
// ext
|
||||
register_algorithm_collection_builder(m);
|
||||
}
|
||||
}
|
||||
@@ -11,7 +11,7 @@ namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_env(nb::module_& m) {
|
||||
nb::class_<Env>(m, "Env")
|
||||
nb::class_<Env>(m, "CppEnv")
|
||||
.def_ro("debug", &Env::debug)
|
||||
.def_ro("debug_subsys", &Env::debugSubsys)
|
||||
.def_ro("debug_file", &Env::debugFile)
|
||||
@@ -20,9 +20,11 @@ void register_env(nb::module_& m) {
|
||||
.def_ro("socket_family", &Env::socketFamily)
|
||||
.def_ro("socket_ifname", &Env::socketIfname)
|
||||
.def_ro("comm_id", &Env::commId)
|
||||
.def_ro("execution_plan_dir", &Env::executionPlanDir)
|
||||
.def_ro("ibv_mode", &Env::ibvMode)
|
||||
.def_ro("cache_dir", &Env::cacheDir)
|
||||
.def_ro("npkit_dump_dir", &Env::npkitDumpDir)
|
||||
.def_ro("cuda_ipc_use_default_stream", &Env::cudaIpcUseDefaultStream);
|
||||
.def_ro("cuda_ipc_use_default_stream", &Env::cudaIpcUseDefaultStream)
|
||||
.def_ro("ib_gid_index", &Env::ibGidIndex);
|
||||
|
||||
m.def("env", &env);
|
||||
}
|
||||
|
||||
@@ -11,18 +11,18 @@ using namespace mscclpp;
|
||||
|
||||
#define REGISTER_EXCEPTION_TRANSLATOR(name_) \
|
||||
nb::register_exception_translator( \
|
||||
[](const std::exception_ptr &p, void *payload) { \
|
||||
[](const std::exception_ptr& p, void* payload) { \
|
||||
try { \
|
||||
std::rethrow_exception(p); \
|
||||
} catch (const name_ &e) { \
|
||||
PyErr_SetObject(reinterpret_cast<PyObject *>(payload), \
|
||||
} catch (const name_& e) { \
|
||||
PyErr_SetObject(reinterpret_cast<PyObject*>(payload), \
|
||||
PyTuple_Pack(2, PyLong_FromLong(long(e.getErrorCode())), PyUnicode_FromString(e.what()))); \
|
||||
} \
|
||||
}, \
|
||||
m.attr(#name_).ptr());
|
||||
|
||||
void register_error(nb::module_ &m) {
|
||||
nb::enum_<ErrorCode>(m, "ErrorCode")
|
||||
void register_error(nb::module_& m) {
|
||||
nb::enum_<ErrorCode>(m, "CppErrorCode")
|
||||
.value("SystemError", ErrorCode::SystemError)
|
||||
.value("InternalError", ErrorCode::InternalError)
|
||||
.value("RemoteError", ErrorCode::RemoteError)
|
||||
|
||||
@@ -15,16 +15,16 @@ namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_executor(nb::module_& m) {
|
||||
nb::enum_<PacketType>(m, "PacketType").value("LL8", PacketType::LL8).value("LL16", PacketType::LL16);
|
||||
nb::enum_<PacketType>(m, "CppPacketType").value("LL8", PacketType::LL8).value("LL16", PacketType::LL16);
|
||||
|
||||
nb::class_<ExecutionPlan>(m, "ExecutionPlan")
|
||||
nb::class_<ExecutionPlan>(m, "CppExecutionPlan")
|
||||
.def(nb::init<const std::string&, int>(), nb::arg("planPath"), nb::arg("rank"))
|
||||
.def_prop_ro("name", [](const ExecutionPlan& self) -> std::string { return self.name(); })
|
||||
.def_prop_ro("collective", [](const ExecutionPlan& self) -> std::string { return self.collective(); })
|
||||
.def_prop_ro("min_message_size", [](const ExecutionPlan& self) -> size_t { return self.minMessageSize(); })
|
||||
.def_prop_ro("max_message_size", [](const ExecutionPlan& self) -> size_t { return self.maxMessageSize(); });
|
||||
|
||||
nb::class_<Executor>(m, "Executor")
|
||||
nb::class_<Executor>(m, "CppExecutor")
|
||||
.def(nb::init<std::shared_ptr<Communicator>>(), nb::arg("comm"))
|
||||
.def(
|
||||
"execute",
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/function.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/unordered_map.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
@@ -15,7 +16,7 @@ using namespace mscclpp;
|
||||
using namespace mscclpp::collective;
|
||||
|
||||
void register_algorithm_collection_builder(nb::module_& m) {
|
||||
nb::class_<AlgorithmCollectionBuilder>(m, "AlgorithmCollectionBuilder")
|
||||
nb::class_<AlgorithmCollectionBuilder>(m, "CppAlgorithmCollectionBuilder")
|
||||
.def_static("get_instance", &AlgorithmCollectionBuilder::getInstance)
|
||||
.def("add_algorithm_builder", &AlgorithmCollectionBuilder::addAlgorithmBuilder, nb::arg("builder"))
|
||||
.def(
|
||||
@@ -29,6 +30,6 @@ void register_algorithm_collection_builder(nb::module_& m) {
|
||||
nb::arg("selector"))
|
||||
.def("build", &AlgorithmCollectionBuilder::build)
|
||||
.def("build_default_algorithms", &AlgorithmCollectionBuilder::buildDefaultAlgorithms, nb::arg("scratch_buffer"),
|
||||
nb::arg("scratch_buffer_size"), nb::arg("rank"))
|
||||
nb::arg("scratch_buffer_size"), nb::arg("flag_buffer"), nb::arg("flag_buffer_size"), nb::arg("rank"))
|
||||
.def_static("reset", &AlgorithmCollectionBuilder::reset);
|
||||
}
|
||||
@@ -9,7 +9,7 @@ namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_fifo(nb::module_& m) {
|
||||
nb::class_<ProxyTrigger>(m, "ProxyTrigger")
|
||||
nb::class_<ProxyTrigger>(m, "CppProxyTrigger")
|
||||
.def_prop_rw(
|
||||
"fst", [](const ProxyTrigger& self) { return self.fst; },
|
||||
[](ProxyTrigger& self, uint64_t v) { self.fst = v; })
|
||||
@@ -17,7 +17,7 @@ void register_fifo(nb::module_& m) {
|
||||
"snd", [](const ProxyTrigger& self) { return self.snd; },
|
||||
[](ProxyTrigger& self, uint64_t v) { self.snd = v; });
|
||||
|
||||
nb::class_<FifoDeviceHandle>(m, "FifoDeviceHandle")
|
||||
nb::class_<FifoDeviceHandle>(m, "CppFifoDeviceHandle")
|
||||
.def_rw("triggers", &FifoDeviceHandle::triggers)
|
||||
.def_rw("tail", &FifoDeviceHandle::tail)
|
||||
.def_rw("head", &FifoDeviceHandle::head)
|
||||
@@ -26,7 +26,7 @@ void register_fifo(nb::module_& m) {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<Fifo>(m, "Fifo")
|
||||
nb::class_<Fifo>(m, "CppFifo")
|
||||
.def(nb::init<int>(), nb::arg("size") = DEFAULT_FIFO_SIZE)
|
||||
.def("poll", &Fifo::poll)
|
||||
.def("pop", &Fifo::pop)
|
||||
|
||||
@@ -34,6 +34,19 @@ static DLDataType getDlType(std::string type) {
|
||||
return DLDataType{kDLBfloat, 16, 1};
|
||||
} else if (type == "torch.float16") {
|
||||
return DLDataType{kDLFloat, 16, 1};
|
||||
} else if (type == "torch.float8_e4m3fn") {
|
||||
return DLDataType{kDLFloat8_e4m3fn, 8, 1};
|
||||
} else if (type == "torch.float8_e4m3fnuz") {
|
||||
return DLDataType{kDLFloat8_e4m3fnuz, 8, 1};
|
||||
} else if (type == "torch.float8_e5m2") {
|
||||
return DLDataType{kDLFloat8_e5m2, 8, 1};
|
||||
} else if (type == "torch.float8_e5m2fnuz") {
|
||||
return DLDataType{kDLFloat8_e5m2fnuz, 8, 1};
|
||||
} else if (type == "torch.uint8") {
|
||||
return DLDataType{kDLUInt, 8, 1};
|
||||
} else if (type == "fp8_e4m3b15") {
|
||||
// No standard DLPack code for fp8_e4m3b15; store as raw uint8 bytes.
|
||||
return DLDataType{kDLUInt, 8, 1};
|
||||
} else {
|
||||
throw Error("Unsupported type: " + type, ErrorCode::InvalidUsage);
|
||||
}
|
||||
@@ -101,7 +114,7 @@ static nb::capsule toDlpack(GpuBuffer<char> buffer, std::string dataType, std::v
|
||||
void register_gpu_utils(nb::module_& m) {
|
||||
m.def("is_nvls_supported", &isNvlsSupported);
|
||||
|
||||
nb::class_<GpuBuffer<char>>(m, "RawGpuBuffer")
|
||||
nb::class_<GpuBuffer<char>>(m, "CppRawGpuBuffer")
|
||||
.def(nb::init<size_t>(), nb::arg("nelems"))
|
||||
.def("nelems", &GpuBuffer<char>::nelems)
|
||||
.def("bytes", &GpuBuffer<char>::bytes)
|
||||
|
||||
@@ -11,20 +11,20 @@ namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_memory_channel(nb::module_& m) {
|
||||
nb::class_<BaseMemoryChannel>(m, "BaseMemoryChannel")
|
||||
nb::class_<BaseMemoryChannel>(m, "CppBaseMemoryChannel")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init<std::shared_ptr<MemoryDevice2DeviceSemaphore>>(), nb::arg("semaphore"))
|
||||
.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
|
||||
.def("device_handle", &BaseMemoryChannel::deviceHandle);
|
||||
|
||||
nb::class_<BaseMemoryChannel::DeviceHandle>(m, "BaseMemoryChannelDeviceHandle")
|
||||
nb::class_<BaseMemoryChannel::DeviceHandle>(m, "CppBaseMemoryChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("semaphore_", &BaseMemoryChannel::DeviceHandle::semaphore_)
|
||||
.def_prop_ro("raw", [](const BaseMemoryChannel::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<MemoryChannel>(m, "MemoryChannel")
|
||||
nb::class_<MemoryChannel>(m, "CppMemoryChannel")
|
||||
.def(nb::init<>())
|
||||
.def(
|
||||
"__init__",
|
||||
@@ -42,7 +42,7 @@ void register_memory_channel(nb::module_& m) {
|
||||
nb::arg("semaphore"), nb::arg("dst"), nb::arg("src"), nb::arg("packet_buffer") = 0)
|
||||
.def("device_handle", &MemoryChannel::deviceHandle);
|
||||
|
||||
nb::class_<MemoryChannel::DeviceHandle>(m, "MemoryChannelDeviceHandle")
|
||||
nb::class_<MemoryChannel::DeviceHandle>(m, "CppMemoryChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("semaphore_", &MemoryChannel::DeviceHandle::semaphore_)
|
||||
.def_rw("dst_", &MemoryChannel::DeviceHandle::dst_)
|
||||
|
||||
@@ -8,8 +8,8 @@
|
||||
|
||||
namespace nb = nanobind;
|
||||
|
||||
void register_npkit(nb::module_ &m) {
|
||||
nb::module_ sub_m = m.def_submodule("npkit", "NPKit functions");
|
||||
void register_npkit(nb::module_& m) {
|
||||
nb::module_ sub_m = m.def_submodule("cpp_npkit", "NPKit functions");
|
||||
sub_m.def("init", &NpKit::Init);
|
||||
sub_m.def("dump", &NpKit::Dump);
|
||||
sub_m.def("shutdown", &NpKit::Shutdown);
|
||||
|
||||
@@ -6,8 +6,8 @@ int getDeviceNumaNode(int cudaDev);
|
||||
void numaBind(int node);
|
||||
}; // namespace mscclpp
|
||||
|
||||
void register_numa(nb::module_ &m) {
|
||||
nb::module_ sub_m = m.def_submodule("numa", "numa functions");
|
||||
void register_numa(nb::module_& m) {
|
||||
nb::module_ sub_m = m.def_submodule("cpp_numa", "numa functions");
|
||||
sub_m.def("get_device_numa_node", &mscclpp::getDeviceNumaNode);
|
||||
sub_m.def("numa_bind", &mscclpp::numaBind);
|
||||
}
|
||||
|
||||
@@ -11,11 +11,11 @@ namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_port_channel(nb::module_& m) {
|
||||
nb::class_<BaseProxyService>(m, "BaseProxyService")
|
||||
nb::class_<BaseProxyService>(m, "CppBaseProxyService")
|
||||
.def("start_proxy", &BaseProxyService::startProxy, nb::arg("blocking") = false)
|
||||
.def("stop_proxy", &BaseProxyService::stopProxy);
|
||||
|
||||
nb::class_<ProxyService, BaseProxyService>(m, "ProxyService")
|
||||
nb::class_<ProxyService, BaseProxyService>(m, "CppProxyService")
|
||||
.def(nb::init<int>(), nb::arg("fifo_size") = DEFAULT_FIFO_SIZE)
|
||||
.def("start_proxy", &ProxyService::startProxy, nb::arg("blocking") = false)
|
||||
.def("stop_proxy", &ProxyService::stopProxy)
|
||||
@@ -31,13 +31,13 @@ void register_port_channel(nb::module_& m) {
|
||||
.def("base_port_channel", &ProxyService::basePortChannel, nb::arg("id"))
|
||||
.def("port_channel", &ProxyService::portChannel, nb::arg("id"), nb::arg("dst"), nb::arg("src"));
|
||||
|
||||
nb::class_<BasePortChannel>(m, "BasePortChannel")
|
||||
nb::class_<BasePortChannel>(m, "CppBasePortChannel")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init<SemaphoreId, std::shared_ptr<Host2DeviceSemaphore>, std::shared_ptr<Proxy>>(),
|
||||
nb::arg("semaphore_id"), nb::arg("semaphore"), nb::arg("proxy"))
|
||||
.def("device_handle", &BasePortChannel::deviceHandle);
|
||||
|
||||
nb::class_<BasePortChannel::DeviceHandle>(m, "BasePortChannelDeviceHandle")
|
||||
nb::class_<BasePortChannel::DeviceHandle>(m, "CppBasePortChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("semaphore_id_", &BasePortChannel::DeviceHandle::semaphoreId_)
|
||||
.def_rw("semaphore_", &BasePortChannel::DeviceHandle::semaphore_)
|
||||
@@ -46,13 +46,13 @@ void register_port_channel(nb::module_& m) {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<PortChannel>(m, "PortChannel")
|
||||
nb::class_<PortChannel>(m, "CppPortChannel")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init<SemaphoreId, std::shared_ptr<Host2DeviceSemaphore>, std::shared_ptr<Proxy>, MemoryId, MemoryId>(),
|
||||
nb::arg("semaphore_id"), nb::arg("semaphore"), nb::arg("proxy"), nb::arg("dst"), nb::arg("src"))
|
||||
.def("device_handle", &PortChannel::deviceHandle);
|
||||
|
||||
nb::class_<PortChannel::DeviceHandle>(m, "PortChannelDeviceHandle")
|
||||
nb::class_<PortChannel::DeviceHandle>(m, "CppPortChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("semaphore_id_", &PortChannel::DeviceHandle::semaphoreId_)
|
||||
.def_rw("semaphore_", &PortChannel::DeviceHandle::semaphore_)
|
||||
|
||||
@@ -10,7 +10,7 @@ namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_semaphore(nb::module_& m) {
|
||||
nb::class_<Host2DeviceSemaphore> host2DeviceSemaphore(m, "Host2DeviceSemaphore");
|
||||
nb::class_<Host2DeviceSemaphore> host2DeviceSemaphore(m, "CppHost2DeviceSemaphore");
|
||||
host2DeviceSemaphore.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
|
||||
.def(nb::init<Communicator&, const Connection&>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def("connection", &Host2DeviceSemaphore::connection)
|
||||
@@ -25,7 +25,7 @@ void register_semaphore(nb::module_& m) {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<Host2HostSemaphore>(m, "Host2HostSemaphore")
|
||||
nb::class_<Host2HostSemaphore>(m, "CppHost2HostSemaphore")
|
||||
.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
|
||||
.def(nb::init<Communicator&, const Connection&>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def("connection", &Host2HostSemaphore::connection)
|
||||
@@ -34,7 +34,7 @@ void register_semaphore(nb::module_& m) {
|
||||
.def("wait", &Host2HostSemaphore::wait, nb::call_guard<nb::gil_scoped_release>(),
|
||||
nb::arg("max_spin_count") = 10000000);
|
||||
|
||||
nb::class_<MemoryDevice2DeviceSemaphore> memoryDevice2DeviceSemaphore(m, "MemoryDevice2DeviceSemaphore");
|
||||
nb::class_<MemoryDevice2DeviceSemaphore> memoryDevice2DeviceSemaphore(m, "CppMemoryDevice2DeviceSemaphore");
|
||||
memoryDevice2DeviceSemaphore.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
|
||||
.def(nb::init<Communicator&, const Connection&>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def("connection", &MemoryDevice2DeviceSemaphore::connection)
|
||||
@@ -43,7 +43,6 @@ void register_semaphore(nb::module_& m) {
|
||||
nb::class_<MemoryDevice2DeviceSemaphore::DeviceHandle>(memoryDevice2DeviceSemaphore, "DeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("inbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::inboundToken)
|
||||
.def_rw("outbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::outboundToken)
|
||||
.def_rw("remote_inbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::remoteInboundToken)
|
||||
.def_rw("expected_inbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::expectedInboundToken)
|
||||
.def_prop_ro("raw", [](const MemoryDevice2DeviceSemaphore::DeviceHandle& self) -> nb::bytes {
|
||||
|
||||
@@ -15,11 +15,11 @@ namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_nvls(nb::module_& m) {
|
||||
nb::class_<SwitchChannel>(m, "SwitchChannel")
|
||||
nb::class_<SwitchChannel>(m, "CppSwitchChannel")
|
||||
.def("get_device_ptr", [](SwitchChannel* self) { return (uintptr_t)self->getDevicePtr(); })
|
||||
.def("device_handle", &SwitchChannel::deviceHandle);
|
||||
|
||||
nb::class_<SwitchChannel::DeviceHandle>(m, "DeviceHandle")
|
||||
nb::class_<SwitchChannel::DeviceHandle>(m, "CppSwitchChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("device_ptr", &SwitchChannel::DeviceHandle::devicePtr)
|
||||
.def_rw("mc_ptr", &SwitchChannel::DeviceHandle::mcPtr)
|
||||
@@ -28,7 +28,7 @@ void register_nvls(nb::module_& m) {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<NvlsConnection>(m, "NvlsConnection")
|
||||
nb::class_<NvlsConnection>(m, "CppNvlsConnection")
|
||||
.def("bind_allocated_memory", &NvlsConnection::bindAllocatedMemory, nb::arg("device_ptr"), nb::arg("size"));
|
||||
|
||||
m.def("connect_nvls_collective", &connectNvlsCollective, nb::arg("communicator"), nb::arg("all_ranks"),
|
||||
|
||||
Reference in New Issue
Block a user