Add executor to execute schedule-plan file (#283)

Add executor to execute the JSON schedule file generated by msccl-tools

---------

Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
This commit is contained in:
Binyang Li
2024-04-19 03:10:41 +08:00
committed by GitHub
parent 9406123711
commit 64d837f9ab
27 changed files with 2857 additions and 3 deletions

View File

@@ -19,6 +19,10 @@ from ._mscclpp import (
TcpBootstrap,
Transport,
TransportFlags,
DataType,
Executor,
ExecutionPlan,
PacketType,
version,
is_nvls_supported,
)

View File

@@ -51,6 +51,7 @@ class CommGroup:
self.communicator = Communicator(self.bootstrap)
self.my_rank = self.bootstrap.get_rank()
self.nranks = self.bootstrap.get_n_ranks()
self.nranks_per_node = self.bootstrap.get_n_ranks_per_node()
def barrier(self):
self.bootstrap.barrier()

View File

@@ -20,6 +20,7 @@ extern void register_fifo(nb::module_& m);
extern void register_semaphore(nb::module_& m);
extern void register_utils(nb::module_& m);
extern void register_numa(nb::module_& m);
extern void register_executor(nb::module_& m);
template <typename T>
void def_nonblocking_future(nb::handle& m, const std::string& typestr) {
@@ -35,6 +36,7 @@ void register_core(nb::module_& m) {
nb::class_<Bootstrap>(m, "Bootstrap")
.def("get_rank", &Bootstrap::getRank)
.def("get_n_ranks", &Bootstrap::getNranks)
.def("get_n_ranks_per_node", &Bootstrap::getNranksPerNode)
.def(
"send",
[](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) {
@@ -204,4 +206,5 @@ NB_MODULE(_mscclpp, m) {
register_utils(m);
register_core(m);
register_numa(m);
register_executor(m);
}

View File

@@ -16,7 +16,8 @@ void register_error(nb::module_& m) {
.value("RemoteError", ErrorCode::RemoteError)
.value("InvalidUsage", ErrorCode::InvalidUsage)
.value("Timeout", ErrorCode::Timeout)
.value("Aborted", ErrorCode::Aborted);
.value("Aborted", ErrorCode::Aborted)
.value("ExecutorError", ErrorCode::ExecutorError);
nb::class_<BaseError>(m, "BaseError")
.def(nb::init<std::string&, int>(), nb::arg("message"), nb::arg("errorCode"))

View File

@@ -0,0 +1,38 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <nanobind/nanobind.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/string.h>
#include <mscclpp/executor.hpp>
#include <mscclpp/gpu.hpp>
namespace nb = nanobind;
using namespace mscclpp;
/// Registers the executor-related Python bindings on module `m`.
///
/// Exposes the `DataType` and `PacketType` enums, the `ExecutionPlan` class
/// (constructed from a name and a plan path), and the `Executor` class
/// (constructed from a `Communicator`) together with its `execute` method.
///
/// @param m The nanobind module that receives the bindings.
void register_executor(nb::module_& m) {
  nb::enum_<DataType>(m, "DataType")
      .value("int32", DataType::INT32)
      .value("uint32", DataType::UINT32)
      .value("float16", DataType::FLOAT16)
      .value("float32", DataType::FLOAT32);

  nb::enum_<PacketType>(m, "PacketType").value("LL8", PacketType::LL8).value("LL16", PacketType::LL16);

  nb::class_<ExecutionPlan>(m, "ExecutionPlan")
      .def(nb::init<const std::string, const std::string>(), nb::arg("name"), nb::arg("planPath"));

  nb::class_<Executor>(m, "Executor")
      .def(nb::init<std::shared_ptr<Communicator>>(), nb::arg("comm"))
      .def(
          "execute",
          // The send/receive buffers and the stream are taken as integer values (uintptr_t)
          // and converted back to pointer/handle types before forwarding to Executor::execute.
          [](Executor* self, int rank, uintptr_t sendbuff, uintptr_t recvBuff, size_t sendBuffSize, size_t recvBuffSize,
             DataType dataType, int nthreads, const ExecutionPlan& plan, uintptr_t stream, PacketType packetType) {
            // Named cast for the stream handle, matching the buffer conversions.
            // NOTE(review): relies on cudaStream_t being a pointer type, as the original C-style cast did.
            self->execute(rank, reinterpret_cast<void*>(sendbuff), reinterpret_cast<void*>(recvBuff), sendBuffSize,
                          recvBuffSize, dataType, nthreads, plan, reinterpret_cast<cudaStream_t>(stream), packetType);
          },
          nb::arg("rank"), nb::arg("sendbuff"), nb::arg("recvBuff"), nb::arg("sendBuffSize"), nb::arg("recvBuffSize"),
          nb::arg("dataType"), nb::arg("nthreads"), nb::arg("plan"), nb::arg("stream"),
          // packetType is optional on the Python side and defaults to LL16.
          nb::arg("packetType") = PacketType::LL16);
}