mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-25 23:34:49 +00:00
Refactoring NVLS interfaces (#293)
Move NVLS details from the core to a separate interface
This commit is contained in:
@@ -98,7 +98,7 @@ class CommGroup:
|
||||
else:
|
||||
endpoint = endpoints
|
||||
if endpoint.transport == Transport.Nvls:
|
||||
return self.communicator.connct_nvls_collective(all_ranks, endpoint)
|
||||
return connect_nvls_collective(self.communicator, all_ranks)
|
||||
else:
|
||||
connections[rank] = self.communicator.connect_on_setup(rank, 0, endpoint)
|
||||
self.communicator.setup()
|
||||
|
||||
@@ -20,6 +20,7 @@ extern void register_fifo(nb::module_& m);
|
||||
extern void register_semaphore(nb::module_& m);
|
||||
extern void register_utils(nb::module_& m);
|
||||
extern void register_numa(nb::module_& m);
|
||||
extern void register_nvls(nb::module_& m);
|
||||
extern void register_executor(nb::module_& m);
|
||||
|
||||
template <typename T>
|
||||
@@ -128,24 +129,6 @@ void register_core(nb::module_& m) {
|
||||
.def("transport", &Connection::transport)
|
||||
.def("remote_transport", &Connection::remoteTransport);
|
||||
|
||||
nb::class_<NvlsConnection::DeviceMulticastPointer>(m, "DeviceMulticastPointer")
|
||||
.def("get_device_ptr",
|
||||
[](NvlsConnection::DeviceMulticastPointer* self) { return (uintptr_t)self->getDevicePtr(); })
|
||||
.def("device_handle", &NvlsConnection::DeviceMulticastPointer::deviceHandle);
|
||||
|
||||
nb::class_<NvlsConnection::DeviceMulticastPointer::DeviceHandle>(m, "DeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("devicePtr", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::devicePtr)
|
||||
.def_rw("mcPtr", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::mcPtr)
|
||||
.def_rw("size", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::bufferSize)
|
||||
.def_prop_ro("raw", [](const NvlsConnection::DeviceMulticastPointer::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<NvlsConnection>(m, "NvlsConnection")
|
||||
.def("allocate_bind_memory", &NvlsConnection::allocateAndBindCuda)
|
||||
.def("get_multicast_min_granularity", &NvlsConnection::getMultiCastMinGranularity);
|
||||
|
||||
nb::class_<Endpoint>(m, "Endpoint")
|
||||
.def("transport", &Endpoint::transport)
|
||||
.def("serialize", &Endpoint::serialize)
|
||||
@@ -154,7 +137,6 @@ void register_core(nb::module_& m) {
|
||||
nb::class_<EndpointConfig>(m, "EndpointConfig")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
|
||||
.def(nb::init<Transport, size_t>(), nb::arg("transport"), nb::arg("nvlsBufferSize"))
|
||||
.def_rw("transport", &EndpointConfig::transport)
|
||||
.def_rw("ib_max_cq_size", &EndpointConfig::ibMaxCqSize)
|
||||
.def_rw("ib_max_cq_poll_num", &EndpointConfig::ibMaxCqPollNum)
|
||||
@@ -191,7 +173,6 @@ void register_core(nb::module_& m) {
|
||||
.def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag"))
|
||||
.def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"),
|
||||
nb::arg("localConfig"))
|
||||
.def("connct_nvls_collective", &Communicator::connectNvlsCollective, nb::arg("allRanks"), nb::arg("config"))
|
||||
.def("remote_rank_of", &Communicator::remoteRankOf)
|
||||
.def("tag_of", &Communicator::tagOf)
|
||||
.def("setup", &Communicator::setup);
|
||||
@@ -206,5 +187,6 @@ NB_MODULE(_mscclpp, m) {
|
||||
register_utils(m);
|
||||
register_core(m);
|
||||
register_numa(m);
|
||||
register_nvls(m);
|
||||
register_executor(m);
|
||||
}
|
||||
|
||||
38
python/mscclpp/nvls_py.cpp
Normal file
38
python/mscclpp/nvls_py.cpp
Normal file
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/operators.h>
|
||||
#include <nanobind/stl/array.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/nvls.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_nvls(nb::module_& m) {
|
||||
nb::class_<NvlsConnection::DeviceMulticastPointer>(m, "DeviceMulticastPointer")
|
||||
.def("get_device_ptr",
|
||||
[](NvlsConnection::DeviceMulticastPointer* self) { return (uintptr_t)self->getDevicePtr(); })
|
||||
.def("device_handle", &NvlsConnection::DeviceMulticastPointer::deviceHandle);
|
||||
|
||||
nb::class_<NvlsConnection::DeviceMulticastPointer::DeviceHandle>(m, "DeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("devicePtr", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::devicePtr)
|
||||
.def_rw("mcPtr", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::mcPtr)
|
||||
.def_rw("size", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::bufferSize)
|
||||
.def_prop_ro("raw", [](const NvlsConnection::DeviceMulticastPointer::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<NvlsConnection>(m, "NvlsConnection")
|
||||
.def("allocate_bind_memory", &NvlsConnection::allocateAndBindCuda)
|
||||
.def("get_multicast_min_granularity", &NvlsConnection::getMultiCastMinGranularity);
|
||||
|
||||
m.def("connect_nvls_collective", &connectNvlsCollective, nb::arg("communicator"), nb::arg("allRanks"),
|
||||
nb::arg("bufferSize") = NvlsConnection::DefaultNvlsBufferSize);
|
||||
}
|
||||
Reference in New Issue
Block a user