mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-21 13:29:45 +00:00
Pytest (#162)
Port python tests to mscclpp. Please run `mpirun -tag-output -np 8 pytest ./python/test/test_mscclpp.py -x` to start pytest --------- Co-authored-by: Saeed Maleki <saemal@microsoft.com> Co-authored-by: Changho Hwang <changhohwang@microsoft.com> Co-authored-by: Saeed Maleki <30272783+saeedmaleki@users.noreply.github.com>
This commit is contained in:
14
python/mscclpp/CMakeLists.txt
Normal file
14
python/mscclpp/CMakeLists.txt
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
|
||||
include(FetchContent)
|
||||
FetchContent_Declare(nanobind GIT_REPOSITORY https://github.com/wjakob/nanobind.git GIT_TAG v1.4.0)
|
||||
FetchContent_MakeAvailable(nanobind)
|
||||
|
||||
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cpp)
|
||||
nanobind_add_module(mscclpp_py ${SOURCES})
|
||||
set_target_properties(mscclpp_py PROPERTIES OUTPUT_NAME _mscclpp)
|
||||
target_link_libraries(mscclpp_py PRIVATE mscclpp_static)
|
||||
target_include_directories(mscclpp_py PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
|
||||
install(TARGETS mscclpp_py LIBRARY DESTINATION .)
|
||||
157
python/mscclpp/core_py.cpp
Normal file
157
python/mscclpp/core_py.cpp
Normal file
@@ -0,0 +1,157 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/operators.h>
|
||||
#include <nanobind/stl/array.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include <mscclpp/core.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
extern void register_error(nb::module_& m);
|
||||
extern void register_proxy_channel(nb::module_& m);
|
||||
extern void register_sm_channel(nb::module_& m);
|
||||
extern void register_fifo(nb::module_& m);
|
||||
extern void register_semaphore(nb::module_& m);
|
||||
extern void register_utils(nb::module_& m);
|
||||
extern void register_numa(nb::module_& m);
|
||||
|
||||
template <typename T>
|
||||
void def_nonblocking_future(nb::handle& m, const std::string& typestr) {
|
||||
std::string pyclass_name = std::string("NonblockingFuture") + typestr;
|
||||
nb::class_<NonblockingFuture<T>>(m, pyclass_name.c_str())
|
||||
.def("ready", &NonblockingFuture<T>::ready)
|
||||
.def("get", &NonblockingFuture<T>::get);
|
||||
}
|
||||
|
||||
void register_core(nb::module_& m) {
|
||||
nb::class_<Bootstrap>(m, "Bootstrap")
|
||||
.def("get_rank", &Bootstrap::getRank)
|
||||
.def("get_n_ranks", &Bootstrap::getNranks)
|
||||
.def(
|
||||
"send",
|
||||
[](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) {
|
||||
void* data = reinterpret_cast<void*>(ptr);
|
||||
self->send(data, size, peer, tag);
|
||||
},
|
||||
nb::arg("data"), nb::arg("size"), nb::arg("peer"), nb::arg("tag"))
|
||||
.def(
|
||||
"recv",
|
||||
[](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) {
|
||||
void* data = reinterpret_cast<void*>(ptr);
|
||||
self->recv(data, size, peer, tag);
|
||||
},
|
||||
nb::arg("data"), nb::arg("size"), nb::arg("peer"), nb::arg("tag"))
|
||||
.def("all_gather", &Bootstrap::allGather, nb::arg("allData"), nb::arg("size"))
|
||||
.def("barrier", &Bootstrap::barrier)
|
||||
.def("send", (void (Bootstrap::*)(const std::vector<char>&, int, int)) & Bootstrap::send, nb::arg("data"),
|
||||
nb::arg("peer"), nb::arg("tag"))
|
||||
.def("recv", (void (Bootstrap::*)(std::vector<char>&, int, int)) & Bootstrap::recv, nb::arg("data"),
|
||||
nb::arg("peer"), nb::arg("tag"));
|
||||
|
||||
nb::class_<UniqueId>(m, "UniqueId");
|
||||
|
||||
nb::class_<TcpBootstrap, Bootstrap>(m, "TcpBootstrap")
|
||||
.def(nb::init<int, int>(), "Do not use this constructor. Use create instead.")
|
||||
.def_static(
|
||||
"create", [](int rank, int nRanks) { return std::make_shared<TcpBootstrap>(rank, nRanks); }, nb::arg("rank"),
|
||||
nb::arg("nRanks"))
|
||||
.def("create_unique_id", &TcpBootstrap::createUniqueId)
|
||||
.def("get_unique_id", &TcpBootstrap::getUniqueId)
|
||||
.def("initialize", (void (TcpBootstrap::*)(UniqueId, int64_t)) & TcpBootstrap::initialize, nb::arg("uniqueId"),
|
||||
nb::arg("timeoutSec") = 30)
|
||||
.def("initialize", (void (TcpBootstrap::*)(const std::string&, int64_t)) & TcpBootstrap::initialize,
|
||||
nb::arg("ifIpPortTrio"), nb::arg("timeoutSec") = 30);
|
||||
|
||||
nb::enum_<Transport>(m, "Transport")
|
||||
.value("Unknown", Transport::Unknown)
|
||||
.value("CudaIpc", Transport::CudaIpc)
|
||||
.value("IB0", Transport::IB0)
|
||||
.value("IB1", Transport::IB1)
|
||||
.value("IB2", Transport::IB2)
|
||||
.value("IB3", Transport::IB3)
|
||||
.value("IB4", Transport::IB4)
|
||||
.value("IB5", Transport::IB5)
|
||||
.value("IB6", Transport::IB6)
|
||||
.value("IB7", Transport::IB7)
|
||||
.value("NumTransports", Transport::NumTransports);
|
||||
|
||||
nb::class_<TransportFlags>(m, "TransportFlags")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
|
||||
.def("has", &TransportFlags::has, nb::arg("transport"))
|
||||
.def("none", &TransportFlags::none)
|
||||
.def("any", &TransportFlags::any)
|
||||
.def("all", &TransportFlags::all)
|
||||
.def("count", &TransportFlags::count)
|
||||
.def(nb::self |= nb::self)
|
||||
.def(nb::self | nb::self)
|
||||
.def(nb::self | Transport())
|
||||
.def(nb::self &= nb::self)
|
||||
.def(nb::self & nb::self)
|
||||
.def(nb::self & Transport())
|
||||
.def(nb::self ^= nb::self)
|
||||
.def(nb::self ^ nb::self)
|
||||
.def(nb::self ^ Transport())
|
||||
.def(~nb::self)
|
||||
.def(nb::self == nb::self)
|
||||
.def(nb::self != nb::self);
|
||||
|
||||
nb::class_<RegisteredMemory>(m, "RegisteredMemory")
|
||||
.def(nb::init<>())
|
||||
.def("data", &RegisteredMemory::data)
|
||||
.def("size", &RegisteredMemory::size)
|
||||
.def("rank", &RegisteredMemory::rank)
|
||||
.def("transports", &RegisteredMemory::transports)
|
||||
.def("serialize", &RegisteredMemory::serialize)
|
||||
.def_static("deserialize", &RegisteredMemory::deserialize, nb::arg("data"));
|
||||
|
||||
nb::class_<Connection>(m, "Connection")
|
||||
.def("write", &Connection::write, nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("srcOffset"),
|
||||
nb::arg("size"))
|
||||
.def(
|
||||
"update_and_sync",
|
||||
[](Connection* self, RegisteredMemory dst, uint64_t dstOffset, uintptr_t src, uint64_t newValue) {
|
||||
self->updateAndSync(dst, dstOffset, (uint64_t*)src, newValue);
|
||||
},
|
||||
nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("newValue"))
|
||||
.def("flush", &Connection::flush, nb::arg("timeoutUsec") = (int64_t)3e7)
|
||||
.def("remote_rank", &Connection::remoteRank)
|
||||
.def("tag", &Connection::tag)
|
||||
.def("transport", &Connection::transport)
|
||||
.def("remote_transport", &Connection::remoteTransport);
|
||||
|
||||
def_nonblocking_future<RegisteredMemory>(m, "RegisteredMemory");
|
||||
|
||||
nb::class_<Communicator>(m, "Communicator")
|
||||
.def(nb::init<std::shared_ptr<Bootstrap>>(), nb::arg("bootstrap"))
|
||||
.def("bootstrap", &Communicator::bootstrap)
|
||||
.def(
|
||||
"register_memory",
|
||||
[](Communicator* self, uintptr_t ptr, size_t size, TransportFlags transports) {
|
||||
return self->registerMemory((void*)ptr, size, transports);
|
||||
},
|
||||
nb::arg("ptr"), nb::arg("size"), nb::arg("transports"))
|
||||
.def("send_memory_on_setup", &Communicator::sendMemoryOnSetup, nb::arg("memory"), nb::arg("remoteRank"),
|
||||
nb::arg("tag"))
|
||||
.def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag"))
|
||||
.def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"),
|
||||
nb::arg("transport"), nb::arg("ibMaxCqSize") = 1024, nb::arg("ibMaxCqPollNum") = 1,
|
||||
nb::arg("ibMaxSendWr") = 8192, nb::arg("ibMaxWrPerSend") = 64)
|
||||
.def("setup", &Communicator::setup);
|
||||
}
|
||||
|
||||
NB_MODULE(_mscclpp, m) {
|
||||
register_error(m);
|
||||
register_proxy_channel(m);
|
||||
register_sm_channel(m);
|
||||
register_fifo(m);
|
||||
register_semaphore(m);
|
||||
register_utils(m);
|
||||
register_core(m);
|
||||
register_numa(m);
|
||||
}
|
||||
41
python/mscclpp/error_py.cpp
Normal file
41
python/mscclpp/error_py.cpp
Normal file
@@ -0,0 +1,41 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include <mscclpp/errors.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_error(nb::module_& m) {
|
||||
nb::enum_<ErrorCode>(m, "ErrorCode")
|
||||
.value("SystemError", ErrorCode::SystemError)
|
||||
.value("InternalError", ErrorCode::InternalError)
|
||||
.value("RemoteError", ErrorCode::RemoteError)
|
||||
.value("InvalidUsage", ErrorCode::InvalidUsage)
|
||||
.value("Timeout", ErrorCode::Timeout)
|
||||
.value("Aborted", ErrorCode::Aborted);
|
||||
|
||||
nb::class_<BaseError>(m, "BaseError")
|
||||
.def(nb::init<std::string&, int>(), nb::arg("message"), nb::arg("errorCode"))
|
||||
.def("get_error_code", &BaseError::getErrorCode)
|
||||
.def("what", &BaseError::what);
|
||||
|
||||
nb::class_<Error, BaseError>(m, "Error")
|
||||
.def(nb::init<const std::string&, ErrorCode>(), nb::arg("message"), nb::arg("errorCode"))
|
||||
.def("get_error_code", &Error::getErrorCode);
|
||||
|
||||
nb::class_<SysError, BaseError>(m, "SysError")
|
||||
.def(nb::init<const std::string&, int>(), nb::arg("message"), nb::arg("errorCode"));
|
||||
|
||||
nb::class_<CudaError, BaseError>(m, "CudaError")
|
||||
.def(nb::init<const std::string&, cudaError_t>(), nb::arg("message"), nb::arg("errorCode"));
|
||||
|
||||
nb::class_<CuError, BaseError>(m, "CuError")
|
||||
.def(nb::init<const std::string&, CUresult>(), nb::arg("message"), nb::arg("errorCode"));
|
||||
|
||||
nb::class_<IbError, BaseError>(m, "IbError")
|
||||
.def(nb::init<const std::string&, int>(), nb::arg("message"), nb::arg("errorCode"));
|
||||
}
|
||||
30
python/mscclpp/fifo_py.cpp
Normal file
30
python/mscclpp/fifo_py.cpp
Normal file
@@ -0,0 +1,30 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
|
||||
#include <mscclpp/fifo.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_fifo(nb::module_& m) {
|
||||
nb::class_<ProxyTrigger>(m, "ProxyTrigger").def_rw("fst", &ProxyTrigger::fst).def_rw("snd", &ProxyTrigger::snd);
|
||||
|
||||
nb::class_<FifoDeviceHandle>(m, "FifoDeviceHandle")
|
||||
.def_rw("triggers", &FifoDeviceHandle::triggers)
|
||||
.def_rw("tail_replica", &FifoDeviceHandle::tailReplica)
|
||||
.def_rw("head", &FifoDeviceHandle::head)
|
||||
.def_rw("size", &FifoDeviceHandle::size)
|
||||
.def_prop_ro("raw", [](const FifoDeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<Fifo>(m, "Fifo")
|
||||
.def(nb::init<int>(), nb::arg("size") = 128)
|
||||
.def("poll", &Fifo::poll)
|
||||
.def("pop", &Fifo::pop)
|
||||
.def("flush_tail", &Fifo::flushTail, nb::arg("sync") = false)
|
||||
.def("size", &Fifo::size)
|
||||
.def("device_handle", &Fifo::deviceHandle);
|
||||
}
|
||||
13
python/mscclpp/numa_py.cpp
Normal file
13
python/mscclpp/numa_py.cpp
Normal file
@@ -0,0 +1,13 @@
|
||||
#include <nanobind/nanobind.h>
|
||||
namespace nb = nanobind;
|
||||
|
||||
namespace mscclpp {
|
||||
int getDeviceNumaNode(int cudaDev);
|
||||
void numaBind(int node);
|
||||
}; // namespace mscclpp
|
||||
|
||||
void register_numa(nb::module_ &m) {
|
||||
nb::module_ sub_m = m.def_submodule("numa", "numa functions");
|
||||
sub_m.def("get_device_numa_node", &mscclpp::getDeviceNumaNode);
|
||||
sub_m.def("numa_bind", &mscclpp::numaBind);
|
||||
}
|
||||
55
python/mscclpp/proxy_channel_py.cpp
Normal file
55
python/mscclpp/proxy_channel_py.cpp
Normal file
@@ -0,0 +1,55 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include <mscclpp/proxy_channel.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_proxy_channel(nb::module_& m) {
|
||||
nb::class_<BaseProxyService>(m, "BaseProxyService")
|
||||
.def("start_proxy", &BaseProxyService::startProxy)
|
||||
.def("stop_proxy", &BaseProxyService::stopProxy);
|
||||
|
||||
nb::class_<ProxyService, BaseProxyService>(m, "ProxyService")
|
||||
.def(nb::init<>())
|
||||
.def("start_proxy", &ProxyService::startProxy)
|
||||
.def("stop_proxy", &ProxyService::stopProxy)
|
||||
.def("build_and_add_semaphore", &ProxyService::buildAndAddSemaphore, nb::arg("comm"), nb::arg("connection"))
|
||||
.def("add_semaphore", &ProxyService::addSemaphore, nb::arg("semaphore"))
|
||||
.def("add_memory", &ProxyService::addMemory, nb::arg("memory"))
|
||||
.def("semaphore", &ProxyService::semaphore, nb::arg("id"))
|
||||
.def("proxy_channel", &ProxyService::proxyChannel, nb::arg("id"));
|
||||
|
||||
nb::class_<ProxyChannel>(m, "ProxyChannel")
|
||||
.def(nb::init<SemaphoreId, Host2DeviceSemaphore::DeviceHandle, FifoDeviceHandle>(), nb::arg("semaphoreId"),
|
||||
nb::arg("semaphore"), nb::arg("fifo"))
|
||||
.def("device_handle", &ProxyChannel::deviceHandle);
|
||||
|
||||
nb::class_<ProxyChannel::DeviceHandle>(m, "ProxyChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("semaphoreId_", &ProxyChannel::DeviceHandle::semaphoreId_)
|
||||
.def_rw("semaphore_", &ProxyChannel::DeviceHandle::semaphore_)
|
||||
.def_rw("fifo_", &ProxyChannel::DeviceHandle::fifo_)
|
||||
.def_prop_ro("raw", [](const ProxyChannel::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<SimpleProxyChannel>(m, "SimpleProxyChannel")
|
||||
.def(nb::init<ProxyChannel, MemoryId, MemoryId>(), nb::arg("proxyChan"), nb::arg("dst"), nb::arg("src"))
|
||||
.def(nb::init<SimpleProxyChannel>(), nb::arg("proxyChan"))
|
||||
.def("device_handle", &SimpleProxyChannel::deviceHandle);
|
||||
|
||||
nb::class_<SimpleProxyChannel::DeviceHandle>(m, "SimpleProxyChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("proxyChan_", &SimpleProxyChannel::DeviceHandle::proxyChan_)
|
||||
.def_rw("src_", &SimpleProxyChannel::DeviceHandle::src_)
|
||||
.def_rw("dst_", &SimpleProxyChannel::DeviceHandle::dst_)
|
||||
.def_prop_ro("raw", [](const SimpleProxyChannel::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
};
|
||||
48
python/mscclpp/semaphore_py.cpp
Normal file
48
python/mscclpp/semaphore_py.cpp
Normal file
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
|
||||
#include <mscclpp/semaphore.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_semaphore(nb::module_& m) {
|
||||
nb::class_<Host2DeviceSemaphore> host2DeviceSemaphore(m, "Host2DeviceSemaphore");
|
||||
host2DeviceSemaphore
|
||||
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def("connection", &Host2DeviceSemaphore::connection)
|
||||
.def("signal", &Host2DeviceSemaphore::signal)
|
||||
.def("device_handle", &Host2DeviceSemaphore::deviceHandle);
|
||||
|
||||
nb::class_<Host2DeviceSemaphore::DeviceHandle>(host2DeviceSemaphore, "DeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("inbound_semaphore_id", &Host2DeviceSemaphore::DeviceHandle::inboundSemaphoreId)
|
||||
.def_rw("expected_inbound_semaphore_id", &Host2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId)
|
||||
.def_prop_ro("raw", [](const Host2DeviceSemaphore::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<Host2HostSemaphore>(m, "Host2HostSemaphore")
|
||||
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def("connection", &Host2HostSemaphore::connection)
|
||||
.def("signal", &Host2HostSemaphore::signal)
|
||||
.def("wait", &Host2HostSemaphore::wait);
|
||||
|
||||
nb::class_<SmDevice2DeviceSemaphore> smDevice2DeviceSemaphore(m, "SmDevice2DeviceSemaphore");
|
||||
smDevice2DeviceSemaphore
|
||||
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def("device_handle", &SmDevice2DeviceSemaphore::deviceHandle);
|
||||
|
||||
nb::class_<SmDevice2DeviceSemaphore::DeviceHandle>(smDevice2DeviceSemaphore, "DeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("inboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::inboundSemaphoreId)
|
||||
.def_rw("outboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::outboundSemaphoreId)
|
||||
.def_rw("remoteInboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::remoteInboundSemaphoreId)
|
||||
.def_rw("expectedInboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId)
|
||||
.def_prop_ro("raw", [](const SmDevice2DeviceSemaphore::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
}
|
||||
35
python/mscclpp/sm_channel_py.cpp
Normal file
35
python/mscclpp/sm_channel_py.cpp
Normal file
@@ -0,0 +1,35 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include <mscclpp/sm_channel.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_sm_channel(nb::module_& m) {
|
||||
nb::class_<SmChannel> smChannel(m, "SmChannel");
|
||||
smChannel
|
||||
.def("__init__",
|
||||
[](SmChannel* smChannel, std::shared_ptr<SmDevice2DeviceSemaphore> semaphore, RegisteredMemory dst,
|
||||
uintptr_t src) { new (smChannel) SmChannel(semaphore, dst, (void*)src); })
|
||||
.def("__init__",
|
||||
[](SmChannel* smChannel, std::shared_ptr<SmDevice2DeviceSemaphore> semaphore, RegisteredMemory dst,
|
||||
uintptr_t src, uintptr_t get_packet_buffer) {
|
||||
new (smChannel) SmChannel(semaphore, dst, (void*)src, (void*)get_packet_buffer);
|
||||
})
|
||||
.def("device_handle", &SmChannel::deviceHandle);
|
||||
|
||||
nb::class_<SmChannel::DeviceHandle>(m, "SmChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("semaphore_", &SmChannel::DeviceHandle::semaphore_)
|
||||
.def_rw("src_", &SmChannel::DeviceHandle::src_)
|
||||
.def_rw("dst_", &SmChannel::DeviceHandle::dst_)
|
||||
.def_rw("getPacketBuffer_", &SmChannel::DeviceHandle::getPacketBuffer_)
|
||||
.def_prop_ro("raw", [](const SmChannel::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
};
|
||||
23
python/mscclpp/utils_py.cpp
Normal file
23
python/mscclpp/utils_py.cpp
Normal file
@@ -0,0 +1,23 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include <mscclpp/utils.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_utils(nb::module_& m) {
|
||||
nb::class_<Timer>(m, "Timer")
|
||||
.def(nb::init<int>(), nb::arg("timeout") = -1)
|
||||
.def("elapsed", &Timer::elapsed)
|
||||
.def("set", &Timer::set, nb::arg("timeout"))
|
||||
.def("reset", &Timer::reset)
|
||||
.def("print", &Timer::print, nb::arg("name"));
|
||||
|
||||
nb::class_<ScopedTimer, Timer>(m, "ScopedTimer").def(nb::init<std::string>(), nb::arg("name"));
|
||||
|
||||
m.def("get_host_name", &getHostName, nb::arg("maxlen"), nb::arg("delim"));
|
||||
}
|
||||
Reference in New Issue
Block a user