// Copyright (c) Microsoft Corporation. // Licensed under the MIT License. #include #include #include #include #include #include #include #include #include namespace nb = nanobind; using namespace mscclpp; void register_algorithm(nb::module_& m) { nb::enum_(m, "CppCollectiveBufferMode") .value("ANY", CollectiveBufferMode::Any) .value("IN_PLACE", CollectiveBufferMode::InPlace) .value("OUT_OF_PLACE", CollectiveBufferMode::OutOfPlace); nb::enum_(m, "CppAlgorithmType") .value("NATIVE", AlgorithmType::Native) .value("DSL", AlgorithmType::DSL); nb::enum_(m, "CppCommResult") .value("COMM_SUCCESS", CommResult::CommSuccess) .value("COMM_UNHANDLED_CUDA_ERROR", CommResult::CommUnhandledCudaError) .value("COMM_SYSTEM_ERROR", CommResult::CommSystemError) .value("COMM_INTERNAL_ERROR", CommResult::CommInternalError) .value("COMM_INVALID_ARGUMENT", CommResult::CommInvalidArgument) .value("COMM_INVALID_USAGE", CommResult::CommInvalidUsage) .value("COMM_REMOTE_ERROR", CommResult::CommRemoteError) .value("COMM_IN_PROGRESS", CommResult::CommInProgress) .value("COMM_NUM_RESULTS", CommResult::CommNumResults); nb::enum_(m, "CppReduceOp") .value("SUM", ReduceOp::SUM) .value("MIN", ReduceOp::MIN) .value("NOP", ReduceOp::NOP); auto algorithmClass = nb::class_(m, "CppAlgorithm") .def_static( "from_native_capsule", [](nb::capsule cap) { const char* name = cap.name(); if (name == nullptr || std::strcmp(name, ALGORITHM_NATIVE_CAPSULE_NAME) != 0) { throw nb::type_error("Invalid capsule: expected 'mscclpp::AlgorithmPtr'"); } void* data = cap.data(); if (data == nullptr) { throw nb::value_error("Failed to get pointer from capsule"); } return *static_cast*>(data); }, nb::arg("capsule")) .def_prop_ro("name", &Algorithm::name) .def_prop_ro("collective", &Algorithm::collective) .def_prop_ro("message_range", &Algorithm::messageRange) .def( "set_message_size_range", [](Algorithm& self, size_t minMessageSize, size_t maxMessageSize) { self.setMessageSizeRange(minMessageSize, maxMessageSize); }, nb::arg("min_message_size"), nb::arg("max_message_size")) .def_prop_ro("tags", &Algorithm::tags) .def_prop_ro("buffer_mode", &Algorithm::bufferMode) .def_prop_ro("constraint", &Algorithm::constraint) .def_prop_ro("type", &Algorithm::type) .def( "execute", [](Algorithm& self, std::shared_ptr comm, uintptr_t input, uintptr_t output, size_t inputSize, size_t outputSize, DataType dtype, ReduceOp op, uintptr_t stream, std::shared_ptr executor, int nBlocks, int nThreadsPerBlock, bool symmetricMemory, std::unordered_map extras) { return self.execute(comm, reinterpret_cast(input), reinterpret_cast(output), inputSize, outputSize, dtype, op, reinterpret_cast(stream), executor, nBlocks, nThreadsPerBlock, symmetricMemory, extras); }, nb::arg("comm"), nb::arg("input"), nb::arg("output"), nb::arg("input_size"), nb::arg("output_size"), nb::arg("dtype"), nb::arg("op") = ReduceOp::NOP, nb::arg("stream") = 0, nb::arg("executor") = nullptr, nb::arg("n_blocks") = 0, nb::arg("n_threads_per_block") = 0, nb::arg("symmetric_memory") = false, nb::arg("extras") = std::unordered_map()) .def("reset", &Algorithm::reset); nb::class_(algorithmClass, "Constraint") .def(nb::init<>()) .def(nb::init(), nb::arg("world_size"), nb::arg("n_ranks_per_node")) .def_rw("world_size", &Algorithm::Constraint::worldSize) .def_rw("n_ranks_per_node", &Algorithm::Constraint::nRanksPerNode); nb::class_(m, "CppAlgorithmBuilder").def("build", &AlgorithmBuilder::build); nb::class_(m, "CppDslAlgorithm") .def(nb::init, Algorithm::Constraint>(), nb::arg("id"), nb::arg("plan"), nb::arg("tags") = std::unordered_map(), nb::arg("constraint") = Algorithm::Constraint()) .def("build", &DslAlgorithm::build); nb::class_(m, "CppAlgorithmCollection") .def("register_algorithm", &AlgorithmCollection::registerAlgorithm, nb::arg("collective"), nb::arg("algo_name"), nb::arg("algorithm")) .def("get_algorithms_by_collective", &AlgorithmCollection::getAlgorithmsByCollective, nb::arg("collective")) .def("to_list", &AlgorithmCollection::getAllAlgorithms); nb::class_(m, "CppCollectiveRequest") .def_ro("world_size", &CollectiveRequest::worldSize) .def_ro("n_ranks_per_node", &CollectiveRequest::nRanksPerNode) .def_ro("rank", &CollectiveRequest::rank) .def_prop_ro("input_buffer", [](const CollectiveRequest& self) { return reinterpret_cast(self.inputBuffer); }) .def_prop_ro("output_buffer", [](const CollectiveRequest& self) { return reinterpret_cast(self.outputBuffer); }) .def_ro("message_size", &CollectiveRequest::messageSize) .def_prop_ro("stream", [](const CollectiveRequest& self) { return reinterpret_cast(self.stream); }) .def_prop_ro("collective", [](const CollectiveRequest& self) { return self.collective; }) .def_ro("dtype", &CollectiveRequest::dtype) .def_prop_ro("hints", [](const CollectiveRequest& self) { return self.hints; }) .def("buffer_mode", &CollectiveRequest::bufferMode); m.def( "cpp_get_flag_buffer", []() { auto [buffer, size] = getFlagBuffer(); uintptr_t ptr = reinterpret_cast(buffer.get()); // Transfer shared_ptr ownership into a capsule so Python's GC manages the lifetime. auto prevent = std::make_unique>(std::move(buffer)); nb::capsule owner(prevent.get(), [](void* p) noexcept { delete static_cast*>(p); }); prevent.release(); // capsule now owns the pointer return nb::make_tuple(ptr, size, owner); }, "Get the default flag buffer. Returns a tuple of (buffer_ptr, buffer_size, owner)."); }