// NOTE(review): This file appears to have been collapsed onto a single line.
// Because that line begins with `//`, a C++ compiler treats ALL of it as one
// comment. That includes the `#include` directives, the
// `namespace nb = nanobind;` alias, both `using namespace` directives and the
// function definition. If this is the whole file, the translation unit
// currently defines nothing.
// NOTE(review): Every angle-bracket span also seems to have been stripped:
//   - the seven `#include` directives name no header;
//   - `nb::class_(m, ...)` has no template argument (presumably
//     `nb::class_<AlgorithmCollectionBuilder>`; confirm);
//   - the lambda parameter `std::shared_ptr algorithm` has no element type
//     (presumably the DSL algorithm type that addAlgorithmBuilder accepts;
//     confirm against the mscclpp headers).
// Prefer restoring the original line breaks and angle-bracket contents from
// version control, if available, rather than reconstructing them by hand.
//
// Intended content once restored:
// register_algorithm_collection_builder(m) binds AlgorithmCollectionBuilder
// into the nanobind module `m` as the Python class
// "CppAlgorithmCollectionBuilder" with these members:
//   get_instance()                  static; AlgorithmCollectionBuilder::getInstance
//   add_algorithm_builder(builder)  AlgorithmCollectionBuilder::addAlgorithmBuilder
//   add_dsl_algorithm_builder(algorithm)
//                                   a lambda that forwards the shared_ptr argument
//                                   to self.addAlgorithmBuilder
//   set_algorithm_selector(selector)
//                                   AlgorithmCollectionBuilder::setAlgorithmSelector
//   set_fallback_algorithm_selector(selector)
//                                   AlgorithmCollectionBuilder::setFallbackAlgorithmSelector
//   build()                         AlgorithmCollectionBuilder::build
//   build_default_algorithms(scratch_buffer, scratch_buffer_size,
//                            flag_buffer, flag_buffer_size, rank)
//                                   AlgorithmCollectionBuilder::buildDefaultAlgorithms
//   reset()                         static; AlgorithmCollectionBuilder::reset
// Copyright (c) Microsoft Corporation. // Licensed under the MIT License. #include #include #include #include #include #include #include namespace nb = nanobind; using namespace mscclpp; using namespace mscclpp::collective; void register_algorithm_collection_builder(nb::module_& m) { nb::class_(m, "CppAlgorithmCollectionBuilder") .def_static("get_instance", &AlgorithmCollectionBuilder::getInstance) .def("add_algorithm_builder", &AlgorithmCollectionBuilder::addAlgorithmBuilder, nb::arg("builder")) .def( "add_dsl_algorithm_builder", [](AlgorithmCollectionBuilder& self, std::shared_ptr algorithm) { self.addAlgorithmBuilder(algorithm); }, nb::arg("algorithm")) .def("set_algorithm_selector", &AlgorithmCollectionBuilder::setAlgorithmSelector, nb::arg("selector")) .def("set_fallback_algorithm_selector", &AlgorithmCollectionBuilder::setFallbackAlgorithmSelector, nb::arg("selector")) .def("build", &AlgorithmCollectionBuilder::build) .def("build_default_algorithms", &AlgorithmCollectionBuilder::buildDefaultAlgorithms, nb::arg("scratch_buffer"), nb::arg("scratch_buffer_size"), nb::arg("flag_buffer"), nb::arg("flag_buffer_size"), nb::arg("rank")) .def_static("reset", &AlgorithmCollectionBuilder::reset); }