Refactoring NVLS interfaces (#293)

Move NVLS details from the core to a separate interface
This commit is contained in:
Changho Hwang
2024-04-24 10:05:41 -07:00
committed by GitHub
parent 9934c982a8
commit 6c1fa5307c
8 changed files with 161 additions and 128 deletions

View File

@@ -98,7 +98,7 @@ class CommGroup:
else:
endpoint = endpoints
if endpoint.transport == Transport.Nvls:
return self.communicator.connct_nvls_collective(all_ranks, endpoint)
return connect_nvls_collective(self.communicator, all_ranks)
else:
connections[rank] = self.communicator.connect_on_setup(rank, 0, endpoint)
self.communicator.setup()

View File

@@ -20,6 +20,7 @@ extern void register_fifo(nb::module_& m);
extern void register_semaphore(nb::module_& m);
extern void register_utils(nb::module_& m);
extern void register_numa(nb::module_& m);
extern void register_nvls(nb::module_& m);
extern void register_executor(nb::module_& m);
template <typename T>
@@ -128,24 +129,6 @@ void register_core(nb::module_& m) {
.def("transport", &Connection::transport)
.def("remote_transport", &Connection::remoteTransport);
nb::class_<NvlsConnection::DeviceMulticastPointer>(m, "DeviceMulticastPointer")
.def("get_device_ptr",
[](NvlsConnection::DeviceMulticastPointer* self) { return (uintptr_t)self->getDevicePtr(); })
.def("device_handle", &NvlsConnection::DeviceMulticastPointer::deviceHandle);
nb::class_<NvlsConnection::DeviceMulticastPointer::DeviceHandle>(m, "DeviceHandle")
.def(nb::init<>())
.def_rw("devicePtr", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::devicePtr)
.def_rw("mcPtr", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::mcPtr)
.def_rw("size", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::bufferSize)
.def_prop_ro("raw", [](const NvlsConnection::DeviceMulticastPointer::DeviceHandle& self) -> nb::bytes {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});
nb::class_<NvlsConnection>(m, "NvlsConnection")
.def("allocate_bind_memory", &NvlsConnection::allocateAndBindCuda)
.def("get_multicast_min_granularity", &NvlsConnection::getMultiCastMinGranularity);
nb::class_<Endpoint>(m, "Endpoint")
.def("transport", &Endpoint::transport)
.def("serialize", &Endpoint::serialize)
@@ -154,7 +137,6 @@ void register_core(nb::module_& m) {
nb::class_<EndpointConfig>(m, "EndpointConfig")
.def(nb::init<>())
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
.def(nb::init<Transport, size_t>(), nb::arg("transport"), nb::arg("nvlsBufferSize"))
.def_rw("transport", &EndpointConfig::transport)
.def_rw("ib_max_cq_size", &EndpointConfig::ibMaxCqSize)
.def_rw("ib_max_cq_poll_num", &EndpointConfig::ibMaxCqPollNum)
@@ -191,7 +173,6 @@ void register_core(nb::module_& m) {
.def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag"))
.def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"),
nb::arg("localConfig"))
.def("connct_nvls_collective", &Communicator::connectNvlsCollective, nb::arg("allRanks"), nb::arg("config"))
.def("remote_rank_of", &Communicator::remoteRankOf)
.def("tag_of", &Communicator::tagOf)
.def("setup", &Communicator::setup);
@@ -206,5 +187,6 @@ NB_MODULE(_mscclpp, m) {
register_utils(m);
register_core(m);
register_numa(m);
register_nvls(m);
register_executor(m);
}

View File

@@ -0,0 +1,38 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <nanobind/nanobind.h>
#include <nanobind/operators.h>
#include <nanobind/stl/array.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>
#include <mscclpp/core.hpp>
#include <mscclpp/nvls.hpp>
namespace nb = nanobind;
using namespace mscclpp;
void register_nvls(nb::module_& m) {
nb::class_<NvlsConnection::DeviceMulticastPointer>(m, "DeviceMulticastPointer")
.def("get_device_ptr",
[](NvlsConnection::DeviceMulticastPointer* self) { return (uintptr_t)self->getDevicePtr(); })
.def("device_handle", &NvlsConnection::DeviceMulticastPointer::deviceHandle);
nb::class_<NvlsConnection::DeviceMulticastPointer::DeviceHandle>(m, "DeviceHandle")
.def(nb::init<>())
.def_rw("devicePtr", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::devicePtr)
.def_rw("mcPtr", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::mcPtr)
.def_rw("size", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::bufferSize)
.def_prop_ro("raw", [](const NvlsConnection::DeviceMulticastPointer::DeviceHandle& self) -> nb::bytes {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});
nb::class_<NvlsConnection>(m, "NvlsConnection")
.def("allocate_bind_memory", &NvlsConnection::allocateAndBindCuda)
.def("get_multicast_min_granularity", &NvlsConnection::getMultiCastMinGranularity);
m.def("connect_nvls_collective", &connectNvlsCollective, nb::arg("communicator"), nb::arg("allRanks"),
nb::arg("bufferSize") = NvlsConnection::DefaultNvlsBufferSize);
}