mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-20 06:49:29 +00:00
Torch integration (#692)
Reorganize current native algorithm implementation and DSL algorithm implementation. Provide unified API for DSL algo and native algo and provide interface to tune the algo Provide interface for pytorch integration with native API and DSL --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> Co-authored-by: chhwang <8018170+chhwang@users.noreply.github.com>
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
|
||||
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
|
||||
include(FetchContent)
|
||||
FetchContent_Declare(nanobind GIT_REPOSITORY https://github.com/wjakob/nanobind.git GIT_TAG v1.4.0)
|
||||
FetchContent_Declare(nanobind GIT_REPOSITORY https://github.com/wjakob/nanobind.git GIT_TAG v1.9.2)
|
||||
FetchContent_MakeAvailable(nanobind)
|
||||
|
||||
FetchContent_Declare(dlpack
|
||||
@@ -21,6 +21,7 @@ endif()
|
||||
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cpp)
|
||||
nanobind_add_module(mscclpp_py ${SOURCES})
|
||||
set_target_properties(mscclpp_py PROPERTIES OUTPUT_NAME _mscclpp)
|
||||
target_link_libraries(mscclpp_py PRIVATE dlpack mscclpp_static ${GPU_LIBRARIES})
|
||||
set_target_properties(mscclpp_py PROPERTIES INSTALL_RPATH "\$ORIGIN/lib")
|
||||
target_link_libraries(mscclpp_py PRIVATE dlpack mscclpp mscclpp_collectives ${GPU_LIBRARIES})
|
||||
target_include_directories(mscclpp_py SYSTEM PRIVATE ${GPU_INCLUDE_DIRS})
|
||||
install(TARGETS mscclpp_py LIBRARY DESTINATION .)
|
||||
|
||||
113
python/csrc/algorithm.cpp
Normal file
113
python/csrc/algorithm.cpp
Normal file
@@ -0,0 +1,113 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/function.h>
|
||||
#include <nanobind/stl/pair.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/unordered_map.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <mscclpp/algorithm.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_algorithm(nb::module_& m) {
|
||||
nb::enum_<CollectiveBufferMode>(m, "CollectiveBufferMode")
|
||||
.value("ANY", CollectiveBufferMode::Any)
|
||||
.value("IN_PLACE", CollectiveBufferMode::InPlace)
|
||||
.value("OUT_OF_PLACE", CollectiveBufferMode::OutOfPlace);
|
||||
|
||||
nb::enum_<AlgorithmType>(m, "AlgorithmType").value("NATIVE", AlgorithmType::Native).value("DSL", AlgorithmType::DSL);
|
||||
|
||||
nb::enum_<CommResult>(m, "CommResult")
|
||||
.value("COMM_SUCCESS", CommResult::CommSuccess)
|
||||
.value("COMM_UNHANDLED_CUDA_ERROR", CommResult::CommUnhandledCudaError)
|
||||
.value("COMM_SYSTEM_ERROR", CommResult::CommSystemError)
|
||||
.value("COMM_INTERNAL_ERROR", CommResult::CommInternalError)
|
||||
.value("COMM_INVALID_ARGUMENT", CommResult::CommInvalidArgument)
|
||||
.value("COMM_INVALID_USAGE", CommResult::CommInvalidUsage)
|
||||
.value("COMM_REMOTE_ERROR", CommResult::CommRemoteError)
|
||||
.value("COMM_IN_PROGRESS", CommResult::CommInProgress)
|
||||
.value("COMM_NUM_RESULTS", CommResult::CommNumResults);
|
||||
|
||||
nb::enum_<ReduceOp>(m, "ReduceOp")
|
||||
.value("SUM", ReduceOp::SUM)
|
||||
.value("MIN", ReduceOp::MIN)
|
||||
.value("NOP", ReduceOp::NOP);
|
||||
|
||||
auto algorithmClass =
|
||||
nb::class_<Algorithm>(m, "Algorithm")
|
||||
.def_static(
|
||||
"from_native_capsule",
|
||||
[](nb::capsule cap) {
|
||||
const char* name = cap.name();
|
||||
if (name == nullptr || std::strcmp(name, ALGORITHM_NATIVE_CAPSULE_NAME) != 0) {
|
||||
throw nb::type_error("Invalid capsule: expected 'mscclpp::AlgorithmPtr'");
|
||||
}
|
||||
void* data = cap.data();
|
||||
if (data == nullptr) {
|
||||
throw nb::value_error("Failed to get pointer from capsule");
|
||||
}
|
||||
return *static_cast<std::shared_ptr<Algorithm>*>(data);
|
||||
},
|
||||
nb::arg("capsule"))
|
||||
.def_prop_ro("name", &Algorithm::name)
|
||||
.def_prop_ro("collective", &Algorithm::collective)
|
||||
.def_prop_ro("message_range", &Algorithm::messageRange)
|
||||
.def_prop_ro("tags", &Algorithm::tags)
|
||||
.def_prop_ro("buffer_mode", &Algorithm::bufferMode)
|
||||
.def_prop_ro("constraint", &Algorithm::constraint)
|
||||
.def_prop_ro("type", &Algorithm::type)
|
||||
.def(
|
||||
"execute",
|
||||
[](Algorithm& self, std::shared_ptr<Communicator> comm, uintptr_t input, uintptr_t output,
|
||||
size_t inputSize, size_t outputSize, DataType dtype, ReduceOp op, uintptr_t stream,
|
||||
std::shared_ptr<Executor> executor, int nBlocks, int nThreadsPerBlock,
|
||||
std::unordered_map<std::string, uintptr_t> extras) {
|
||||
return self.execute(comm, reinterpret_cast<const void*>(input), reinterpret_cast<void*>(output),
|
||||
inputSize, outputSize, dtype, op, reinterpret_cast<cudaStream_t>(stream), executor,
|
||||
nBlocks, nThreadsPerBlock, extras);
|
||||
},
|
||||
nb::arg("comm"), nb::arg("input"), nb::arg("output"), nb::arg("input_size"), nb::arg("output_size"),
|
||||
nb::arg("dtype"), nb::arg("op") = ReduceOp::NOP, nb::arg("stream") = 0, nb::arg("executor") = nullptr,
|
||||
nb::arg("n_blocks") = 0, nb::arg("n_threads_per_block") = 0,
|
||||
nb::arg("extras") = std::unordered_map<std::string, uintptr_t>());
|
||||
|
||||
nb::class_<Algorithm::Constraint>(algorithmClass, "Constraint")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init<int, int>(), nb::arg("world_size"), nb::arg("n_ranks_per_node"))
|
||||
.def_rw("world_size", &Algorithm::Constraint::worldSize)
|
||||
.def_rw("n_ranks_per_node", &Algorithm::Constraint::nRanksPerNode);
|
||||
|
||||
nb::class_<AlgorithmBuilder>(m, "AlgorithmBuilder").def("build", &AlgorithmBuilder::build);
|
||||
|
||||
nb::class_<DslAlgorithm, Algorithm>(m, "DslAlgorithm")
|
||||
.def(nb::init<std::string, ExecutionPlan, std::unordered_map<std::string, uint64_t>, Algorithm::Constraint>(),
|
||||
nb::arg("id"), nb::arg("plan"), nb::arg("tags") = std::unordered_map<std::string, uint64_t>(),
|
||||
nb::arg("constraint") = Algorithm::Constraint())
|
||||
.def("build", &DslAlgorithm::build);
|
||||
|
||||
nb::class_<AlgorithmCollection>(m, "AlgorithmCollection")
|
||||
.def("register_algorithm", &AlgorithmCollection::registerAlgorithm, nb::arg("collective"), nb::arg("algo_name"),
|
||||
nb::arg("algorithm"))
|
||||
.def("get_algorithms_by_collective", &AlgorithmCollection::getAlgorithmsByCollective, nb::arg("collective"))
|
||||
.def("to_list", &AlgorithmCollection::getAllAlgorithms);
|
||||
|
||||
nb::class_<CollectiveRequest>(m, "CollectiveRequest")
|
||||
.def_ro("world_size", &CollectiveRequest::worldSize)
|
||||
.def_ro("n_ranks_per_node", &CollectiveRequest::nRanksPerNode)
|
||||
.def_ro("rank", &CollectiveRequest::rank)
|
||||
.def_prop_ro("input_buffer",
|
||||
[](const CollectiveRequest& self) { return reinterpret_cast<uintptr_t>(self.inputBuffer); })
|
||||
.def_prop_ro("output_buffer",
|
||||
[](const CollectiveRequest& self) { return reinterpret_cast<uintptr_t>(self.outputBuffer); })
|
||||
.def_ro("message_size", &CollectiveRequest::messageSize)
|
||||
.def_prop_ro("collective", [](const CollectiveRequest& self) { return self.collective; })
|
||||
.def_ro("dtype", &CollectiveRequest::dtype)
|
||||
.def_prop_ro("hints", [](const CollectiveRequest& self) { return self.hints; })
|
||||
.def("buffer_mode", &CollectiveRequest::bufferMode);
|
||||
}
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/operators.h>
|
||||
#include <nanobind/stl/array.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
@@ -26,6 +25,10 @@ extern void register_nvls(nb::module_& m);
|
||||
extern void register_executor(nb::module_& m);
|
||||
extern void register_npkit(nb::module_& m);
|
||||
extern void register_gpu_utils(nb::module_& m);
|
||||
extern void register_algorithm(nb::module_& m);
|
||||
|
||||
// ext
|
||||
extern void register_algorithm_collection_builder(nb::module_& m);
|
||||
|
||||
template <typename T>
|
||||
void def_shared_future(nb::handle& m, const std::string& typestr) {
|
||||
@@ -36,6 +39,13 @@ void def_shared_future(nb::handle& m, const std::string& typestr) {
|
||||
void register_core(nb::module_& m) {
|
||||
m.def("version", &version);
|
||||
|
||||
nb::enum_<DataType>(m, "DataType")
|
||||
.value("int32", DataType::INT32)
|
||||
.value("uint32", DataType::UINT32)
|
||||
.value("float16", DataType::FLOAT16)
|
||||
.value("float32", DataType::FLOAT32)
|
||||
.value("bfloat16", DataType::BFLOAT16);
|
||||
|
||||
nb::class_<Bootstrap>(m, "Bootstrap")
|
||||
.def("get_rank", &Bootstrap::getRank)
|
||||
.def("get_n_ranks", &Bootstrap::getNranks)
|
||||
@@ -61,7 +71,15 @@ void register_core(nb::module_& m) {
|
||||
.def("recv", static_cast<void (Bootstrap::*)(std::vector<char>&, int, int)>(&Bootstrap::recv), nb::arg("data"),
|
||||
nb::arg("peer"), nb::arg("tag"));
|
||||
|
||||
nb::class_<UniqueId>(m, "UniqueId");
|
||||
nb::class_<UniqueId>(m, "UniqueId")
|
||||
.def(nb::init<>())
|
||||
.def("__setstate__",
|
||||
[](UniqueId& self, nb::bytes b) {
|
||||
if (nb::len(b) != UniqueIdBytes) throw std::runtime_error("Invalid UniqueId byte size");
|
||||
::memcpy(self.data(), b.c_str(), UniqueIdBytes);
|
||||
})
|
||||
.def("__getstate__",
|
||||
[](const UniqueId& self) { return nb::bytes(reinterpret_cast<const char*>(self.data()), UniqueIdBytes); });
|
||||
|
||||
nb::class_<TcpBootstrap, Bootstrap>(m, "TcpBootstrap")
|
||||
.def(nb::init<int, int>(), "Do not use this constructor. Use create instead.")
|
||||
@@ -284,4 +302,8 @@ NB_MODULE(_mscclpp, m) {
|
||||
register_executor(m);
|
||||
register_npkit(m);
|
||||
register_gpu_utils(m);
|
||||
register_algorithm(m);
|
||||
|
||||
// ext
|
||||
register_algorithm_collection_builder(m);
|
||||
}
|
||||
|
||||
@@ -15,50 +15,8 @@ namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_executor(nb::module_& m) {
|
||||
nb::enum_<DataType>(m, "DataType")
|
||||
.value("int32", DataType::INT32)
|
||||
.value("uint32", DataType::UINT32)
|
||||
.value("float16", DataType::FLOAT16)
|
||||
.value("float32", DataType::FLOAT32)
|
||||
.value("bfloat16", DataType::BFLOAT16);
|
||||
|
||||
nb::enum_<PacketType>(m, "PacketType").value("LL8", PacketType::LL8).value("LL16", PacketType::LL16);
|
||||
|
||||
nb::class_<ExecutionRequest>(m, "ExecutionRequest")
|
||||
.def_ro("world_size", &ExecutionRequest::worldSize)
|
||||
.def_ro("n_ranks_per_node", &ExecutionRequest::nRanksPerNode)
|
||||
.def_prop_ro(
|
||||
"input_buffer",
|
||||
[](const ExecutionRequest& self) -> uintptr_t { return reinterpret_cast<uintptr_t>(self.inputBuffer); })
|
||||
.def_prop_ro(
|
||||
"output_buffer",
|
||||
[](const ExecutionRequest& self) -> uintptr_t { return reinterpret_cast<uintptr_t>(self.outputBuffer); })
|
||||
.def_ro("message_size", &ExecutionRequest::messageSize)
|
||||
.def_prop_ro("collective", [](ExecutionRequest& self) -> const std::string& { return self.collective; })
|
||||
.def_prop_ro("hints", [](ExecutionRequest& self) { return self.hints; });
|
||||
|
||||
nb::class_<ExecutionPlanHandle>(m, "ExecutionPlanHandle")
|
||||
.def_ro("id", &ExecutionPlanHandle::id)
|
||||
.def_ro("constraint", &ExecutionPlanHandle::constraint)
|
||||
.def_ro("plan", &ExecutionPlanHandle::plan)
|
||||
.def_ro("tags", &ExecutionPlanHandle::tags)
|
||||
.def_static("create", &ExecutionPlanHandle::create, nb::arg("id"), nb::arg("world_size"),
|
||||
nb::arg("nranks_per_node"), nb::arg("plan"),
|
||||
nb::arg("tags") = std::unordered_map<std::string, uint64_t>{});
|
||||
|
||||
nb::class_<ExecutionPlanHandle::Constraint>(m, "ExecutionPlanConstraint")
|
||||
.def_ro("world_size", &ExecutionPlanHandle::Constraint::worldSize)
|
||||
.def_ro("n_ranks_per_node", &ExecutionPlanHandle::Constraint::nRanksPerNode);
|
||||
|
||||
nb::class_<ExecutionPlanRegistry>(m, "ExecutionPlanRegistry")
|
||||
.def_static("get_instance", &ExecutionPlanRegistry::getInstance)
|
||||
.def("register_plan", &ExecutionPlanRegistry::registerPlan, nb::arg("planHandle"))
|
||||
.def("get_plans", &ExecutionPlanRegistry::getPlans, nb::arg("collective"))
|
||||
.def("get", &ExecutionPlanRegistry::get, nb::arg("id"))
|
||||
.def("set_selector", &ExecutionPlanRegistry::setSelector, nb::arg("selector"))
|
||||
.def("set_default_selector", &ExecutionPlanRegistry::setDefaultSelector, nb::arg("selector"))
|
||||
.def("clear", &ExecutionPlanRegistry::clear);
|
||||
|
||||
nb::class_<ExecutionPlan>(m, "ExecutionPlan")
|
||||
.def(nb::init<const std::string&, int>(), nb::arg("planPath"), nb::arg("rank"))
|
||||
.def_prop_ro("name", [](const ExecutionPlan& self) -> std::string { return self.name(); })
|
||||
|
||||
34
python/csrc/ext/algorithm_collection_builder_py.cpp
Normal file
34
python/csrc/ext/algorithm_collection_builder_py.cpp
Normal file
@@ -0,0 +1,34 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/function.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/unordered_map.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <mscclpp/algorithm.hpp>
|
||||
#include <mscclpp/ext/collectives/algorithm_collection_builder.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
using namespace mscclpp::collective;
|
||||
|
||||
void register_algorithm_collection_builder(nb::module_& m) {
|
||||
nb::class_<AlgorithmCollectionBuilder>(m, "AlgorithmCollectionBuilder")
|
||||
.def_static("get_instance", &AlgorithmCollectionBuilder::getInstance)
|
||||
.def("add_algorithm_builder", &AlgorithmCollectionBuilder::addAlgorithmBuilder, nb::arg("builder"))
|
||||
.def(
|
||||
"add_dsl_algorithm_builder",
|
||||
[](AlgorithmCollectionBuilder& self, std::shared_ptr<DslAlgorithm> algorithm) {
|
||||
self.addAlgorithmBuilder(algorithm);
|
||||
},
|
||||
nb::arg("algorithm"))
|
||||
.def("set_algorithm_selector", &AlgorithmCollectionBuilder::setAlgorithmSelector, nb::arg("selector"))
|
||||
.def("set_fallback_algorithm_selector", &AlgorithmCollectionBuilder::setFallbackAlgorithmSelector,
|
||||
nb::arg("selector"))
|
||||
.def("build", &AlgorithmCollectionBuilder::build)
|
||||
.def("build_default_algorithms", &AlgorithmCollectionBuilder::buildDefaultAlgorithms, nb::arg("scratch_buffer"),
|
||||
nb::arg("scratch_buffer_size"), nb::arg("rank"))
|
||||
.def_static("reset", &AlgorithmCollectionBuilder::reset);
|
||||
}
|
||||
Reference in New Issue
Block a user