mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-25 15:24:43 +00:00
Add executor to execute schedule-plan file (#283)
Add executor to execute the JSON schedule file generated by msccl-tools --------- Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
This commit is contained in:
@@ -19,6 +19,10 @@ from ._mscclpp import (
|
||||
TcpBootstrap,
|
||||
Transport,
|
||||
TransportFlags,
|
||||
DataType,
|
||||
Executor,
|
||||
ExecutionPlan,
|
||||
PacketType,
|
||||
version,
|
||||
is_nvls_supported,
|
||||
)
|
||||
|
||||
@@ -51,6 +51,7 @@ class CommGroup:
|
||||
self.communicator = Communicator(self.bootstrap)
|
||||
self.my_rank = self.bootstrap.get_rank()
|
||||
self.nranks = self.bootstrap.get_n_ranks()
|
||||
self.nranks_per_node = self.bootstrap.get_n_ranks_per_node()
|
||||
|
||||
def barrier(self):
|
||||
self.bootstrap.barrier()
|
||||
|
||||
@@ -20,6 +20,7 @@ extern void register_fifo(nb::module_& m);
|
||||
extern void register_semaphore(nb::module_& m);
|
||||
extern void register_utils(nb::module_& m);
|
||||
extern void register_numa(nb::module_& m);
|
||||
extern void register_executor(nb::module_& m);
|
||||
|
||||
template <typename T>
|
||||
void def_nonblocking_future(nb::handle& m, const std::string& typestr) {
|
||||
@@ -35,6 +36,7 @@ void register_core(nb::module_& m) {
|
||||
nb::class_<Bootstrap>(m, "Bootstrap")
|
||||
.def("get_rank", &Bootstrap::getRank)
|
||||
.def("get_n_ranks", &Bootstrap::getNranks)
|
||||
.def("get_n_ranks_per_node", &Bootstrap::getNranksPerNode)
|
||||
.def(
|
||||
"send",
|
||||
[](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) {
|
||||
@@ -204,4 +206,5 @@ NB_MODULE(_mscclpp, m) {
|
||||
register_utils(m);
|
||||
register_core(m);
|
||||
register_numa(m);
|
||||
register_executor(m);
|
||||
}
|
||||
|
||||
@@ -16,7 +16,8 @@ void register_error(nb::module_& m) {
|
||||
.value("RemoteError", ErrorCode::RemoteError)
|
||||
.value("InvalidUsage", ErrorCode::InvalidUsage)
|
||||
.value("Timeout", ErrorCode::Timeout)
|
||||
.value("Aborted", ErrorCode::Aborted);
|
||||
.value("Aborted", ErrorCode::Aborted)
|
||||
.value("ExecutorError", ErrorCode::ExecutorError);
|
||||
|
||||
nb::class_<BaseError>(m, "BaseError")
|
||||
.def(nb::init<std::string&, int>(), nb::arg("message"), nb::arg("errorCode"))
|
||||
|
||||
38
python/mscclpp/executor.cpp
Normal file
38
python/mscclpp/executor.cpp
Normal file
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include <mscclpp/executor.hpp>
|
||||
#include <mscclpp/gpu.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_executor(nb::module_& m) {
|
||||
nb::enum_<DataType>(m, "DataType")
|
||||
.value("int32", DataType::INT32)
|
||||
.value("uint32", DataType::UINT32)
|
||||
.value("float16", DataType::FLOAT16)
|
||||
.value("float32", DataType::FLOAT32);
|
||||
|
||||
nb::enum_<PacketType>(m, "PacketType").value("LL8", PacketType::LL8).value("LL16", PacketType::LL16);
|
||||
|
||||
nb::class_<ExecutionPlan>(m, "ExecutionPlan")
|
||||
.def(nb::init<const std::string, const std::string>(), nb::arg("name"), nb::arg("planPath"));
|
||||
|
||||
nb::class_<Executor>(m, "Executor")
|
||||
.def(nb::init<std::shared_ptr<Communicator>>(), nb::arg("comm"))
|
||||
.def(
|
||||
"execute",
|
||||
[](Executor* self, int rank, uintptr_t sendbuff, uintptr_t recvBuff, size_t sendBuffSize, size_t recvBuffSize,
|
||||
DataType dataType, int nthreads, const ExecutionPlan& plan, uintptr_t stream, PacketType packetType) {
|
||||
self->execute(rank, reinterpret_cast<void*>(sendbuff), reinterpret_cast<void*>(recvBuff), sendBuffSize,
|
||||
recvBuffSize, dataType, nthreads, plan, (cudaStream_t)stream, packetType);
|
||||
},
|
||||
nb::arg("rank"), nb::arg("sendbuff"), nb::arg("recvBuff"), nb::arg("sendBuffSize"), nb::arg("recvBuffSize"),
|
||||
nb::arg("dataType"), nb::arg("nthreads"), nb::arg("plan"), nb::arg("stream"),
|
||||
nb::arg("packetType") = PacketType::LL16);
|
||||
}
|
||||
Reference in New Issue
Block a user