diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 39138189..2ef01be6 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -7,7 +7,7 @@ FetchContent_Declare(nanobind GIT_REPOSITORY https://github.com/wjakob/nanobind. FetchContent_MakeAvailable(nanobind) nanobind_add_module(mscclpp_py core_py.cpp error_py.cpp proxy_channel_py.cpp fifo_py.cpp semaphore_py.cpp - config_py.cpp utils_py.cpp) + config_py.cpp utils_py.cpp sm_channel_py.cpp) set_target_properties(mscclpp_py PROPERTIES OUTPUT_NAME mscclpp) target_link_libraries(mscclpp_py PRIVATE mscclpp_static) target_include_directories(mscclpp_py PRIVATE ${CUDAToolkit_INCLUDE_DIRS}) diff --git a/python/core_py.cpp b/python/core_py.cpp index 3ab4af7b..59dc39cf 100644 --- a/python/core_py.cpp +++ b/python/core_py.cpp @@ -14,6 +14,7 @@ using namespace mscclpp; extern void register_error(nb::module_& m); extern void register_proxy_channel(nb::module_& m); +extern void register_sm_channel(nb::module_& m); extern void register_fifo(nb::module_& m); extern void register_semaphore(nb::module_& m); extern void register_config(nb::module_& m); @@ -145,6 +146,7 @@ void register_core(nb::module_& m) { NB_MODULE(mscclpp, m) { register_error(m); register_proxy_channel(m); + register_sm_channel(m); register_fifo(m); register_semaphore(m); register_config(m); diff --git a/python/examples/bootstrap.py b/python/examples/bootstrap.py index ddc98874..b383222f 100644 --- a/python/examples/bootstrap.py +++ b/python/examples/bootstrap.py @@ -48,7 +48,7 @@ def setup_connections(comm, rank, world_size, element_size, proxy_service): proxy_service.add_memory(remote_memories[i].get()), proxy_service.add_memory(reg_mem), ) - simple_proxy_channels.append(proxy_channel) + simple_proxy_channels.append(mscclpp.device_handle(proxy_channel)) comm.setup() return simple_proxy_channels diff --git a/python/proxy_channel_py.cpp b/python/proxy_channel_py.cpp index 1bf8ed30..57281164 100644 --- a/python/proxy_channel_py.cpp +++ b/python/proxy_channel_py.cpp @@ -29,6 +29,9 @@ void register_proxy_channel(nb::module_& m) { nb::arg("semaphore"), nb::arg("fifo")); nb::class_(m, "SimpleProxyChannel") - .def(nb::init(), nb::arg("proxy_chan"), nb::arg("dst"), nb::arg("src")) - .def(nb::init(), nb::arg("proxy_chan")); + .def(nb::init(), nb::arg("proxyChan"), nb::arg("dst"), nb::arg("src")) + .def(nb::init(), nb::arg("proxyChan")); + + m.def("device_handle", &deviceHandle, nb::arg("proxyChannel")); + m.def("device_handle", &deviceHandle, nb::arg("simpleProxyChannel")); }; diff --git a/python/sm_channel_py.cpp b/python/sm_channel_py.cpp new file mode 100644 index 00000000..d02ac30e --- /dev/null +++ b/python/sm_channel_py.cpp @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include +#include +#include + +#include + +namespace nb = nanobind; +using namespace mscclpp; + +void register_sm_channel(nb::module_& m) { + nb::class_ smChannel(m, "SmChannel"); + smChannel + .def(nb::init, RegisteredMemory, void*, void*>(), nb::arg("semaphore"), + nb::arg("dst"), nb::arg("src"), nb::arg("getPacketBuffer")) + .def("device_handle", &SmChannel::deviceHandle); + + nb::class_(smChannel, "DeviceHandle"); + + m.def("device_handle", &deviceHandle, nb::arg("smChannel")); +};