NVLS support. (#250)

Co-authored-by: Saeed Maleki <saemal@microsoft.com>
Co-authored-by: Binyang Li <binyli@microsoft.com>
Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
This commit is contained in:
Saeed Maleki
2024-02-04 20:46:10 -08:00
committed by GitHub
parent 4eb0a08b8c
commit 91d592dcc0
22 changed files with 1172 additions and 56 deletions

View File

@@ -6,6 +6,7 @@ import os as _os
from ._mscclpp import (
Communicator,
Connection,
EndpointConfig,
Fifo,
Host2DeviceSemaphore,
Host2HostSemaphore,
@@ -19,6 +20,7 @@ from ._mscclpp import (
Transport,
TransportFlags,
version,
is_nvls_supported,
)
__version__ = version()

View File

@@ -8,6 +8,7 @@ import cupy as cp
from ._mscclpp import (
Communicator,
Connection,
EndpointConfig,
Host2DeviceSemaphore,
Host2HostSemaphore,
ProxyService,
@@ -79,15 +80,21 @@ class CommGroup:
assert False # only 8 IBs are supported
def make_connection(
self, remote_ranks: list[int], transports: Transport | dict[int, Transport]
self,
all_ranks: list[int],
endpoints: EndpointConfig | Transport | dict[int, EndpointConfig] | dict[int, Transport],
) -> dict[int, Connection]:
if type(endpoints) is Transport:
endpoints = EndpointConfig(endpoints)
if endpoints.transport == Transport.Nvls:
return self.communicator.connct_nvls_collective(all_ranks, endpoints)
connections = {}
for rank in remote_ranks:
if type(transports) is dict:
transport = transports[rank]
for rank in all_ranks:
if type(endpoints) is dict:
endpoint = endpoints[rank]
else:
transport = transports
connections[rank] = self.communicator.connect_on_setup(rank, 0, transport)
endpoint = endpoints
connections[rank] = self.communicator.connect_on_setup(rank, 0, endpoint)
self.communicator.setup()
connections = {rank: connections[rank].get() for rank in connections}
return connections

View File

@@ -6,6 +6,7 @@
#include <nanobind/stl/array.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>
#include <mscclpp/core.hpp>
@@ -72,6 +73,7 @@ void register_core(nb::module_& m) {
nb::enum_<Transport>(m, "Transport")
.value("Unknown", Transport::Unknown)
.value("CudaIpc", Transport::CudaIpc)
.value("Nvls", Transport::Nvls)
.value("IB0", Transport::IB0)
.value("IB1", Transport::IB1)
.value("IB2", Transport::IB2)
@@ -124,6 +126,24 @@ void register_core(nb::module_& m) {
.def("transport", &Connection::transport)
.def("remote_transport", &Connection::remoteTransport);
nb::class_<NvlsConnection::DeviceMulticastPointer>(m, "DeviceMulticastPointer")
.def("get_device_ptr",
[](NvlsConnection::DeviceMulticastPointer* self) { return (uintptr_t)self->getDevicePtr(); })
.def("device_handle", &NvlsConnection::DeviceMulticastPointer::deviceHandle);
nb::class_<NvlsConnection::DeviceMulticastPointer::DeviceHandle>(m, "DeviceHandle")
.def(nb::init<>())
.def_rw("devicePtr", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::devicePtr)
.def_rw("mcPtr", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::mcPtr)
.def_rw("size", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::bufferSize)
.def_prop_ro("raw", [](const NvlsConnection::DeviceMulticastPointer::DeviceHandle& self) -> nb::bytes {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});
nb::class_<NvlsConnection>(m, "NvlsConnection")
.def("allocate_bind_memory", &NvlsConnection::allocateAndBindCuda)
.def("get_multicast_min_granularity", &NvlsConnection::getMultiCastMinGranularity);
nb::class_<Endpoint>(m, "Endpoint")
.def("transport", &Endpoint::transport)
.def("serialize", &Endpoint::serialize)
@@ -132,6 +152,7 @@ void register_core(nb::module_& m) {
nb::class_<EndpointConfig>(m, "EndpointConfig")
.def(nb::init<>())
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
.def(nb::init<Transport, size_t>(), nb::arg("transport"), nb::arg("nvlsBufferSize"))
.def_rw("transport", &EndpointConfig::transport)
.def_rw("ib_max_cq_size", &EndpointConfig::ibMaxCqSize)
.def_rw("ib_max_cq_poll_num", &EndpointConfig::ibMaxCqPollNum)
@@ -168,6 +189,7 @@ void register_core(nb::module_& m) {
.def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag"))
.def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"),
nb::arg("localConfig"))
.def("connct_nvls_collective", &Communicator::connctNvlsCollective, nb::arg("allRanks"), nb::arg("config"))
.def("remote_rank_of", &Communicator::remoteRankOf)
.def("tag_of", &Communicator::tagOf)
.def("setup", &Communicator::setup);

View File

@@ -20,4 +20,5 @@ void register_utils(nb::module_& m) {
nb::class_<ScopedTimer, Timer>(m, "ScopedTimer").def(nb::init<std::string>(), nb::arg("name"));
m.def("get_host_name", &getHostName, nb::arg("maxlen"), nb::arg("delim"));
m.def("is_nvls_supported", &isNvlsSupported);
}