mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-25 23:34:49 +00:00
NVLS support. (#250)
Co-authored-by: Saeed Maleki <saemal@microsoft.com> Co-authored-by: Binyang Li <binyli@microsoft.com> Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
This commit is contained in:
@@ -6,6 +6,7 @@ import os as _os
|
||||
from ._mscclpp import (
|
||||
Communicator,
|
||||
Connection,
|
||||
EndpointConfig,
|
||||
Fifo,
|
||||
Host2DeviceSemaphore,
|
||||
Host2HostSemaphore,
|
||||
@@ -19,6 +20,7 @@ from ._mscclpp import (
|
||||
Transport,
|
||||
TransportFlags,
|
||||
version,
|
||||
is_nvls_supported,
|
||||
)
|
||||
|
||||
__version__ = version()
|
||||
|
||||
@@ -8,6 +8,7 @@ import cupy as cp
|
||||
from ._mscclpp import (
|
||||
Communicator,
|
||||
Connection,
|
||||
EndpointConfig,
|
||||
Host2DeviceSemaphore,
|
||||
Host2HostSemaphore,
|
||||
ProxyService,
|
||||
@@ -79,15 +80,21 @@ class CommGroup:
|
||||
assert False # only 8 IBs are supported
|
||||
|
||||
def make_connection(
|
||||
self, remote_ranks: list[int], transports: Transport | dict[int, Transport]
|
||||
self,
|
||||
all_ranks: list[int],
|
||||
endpoints: EndpointConfig | Transport | dict[int, EndpointConfig] | dict[int, Transport],
|
||||
) -> dict[int, Connection]:
|
||||
if type(endpoints) is Transport:
|
||||
endpoints = EndpointConfig(endpoints)
|
||||
if endpoints.transport == Transport.Nvls:
|
||||
return self.communicator.connct_nvls_collective(all_ranks, endpoints)
|
||||
connections = {}
|
||||
for rank in remote_ranks:
|
||||
if type(transports) is dict:
|
||||
transport = transports[rank]
|
||||
for rank in all_ranks:
|
||||
if type(endpoints) is dict:
|
||||
endpoint = endpoints[rank]
|
||||
else:
|
||||
transport = transports
|
||||
connections[rank] = self.communicator.connect_on_setup(rank, 0, transport)
|
||||
endpoint = endpoints
|
||||
connections[rank] = self.communicator.connect_on_setup(rank, 0, endpoint)
|
||||
self.communicator.setup()
|
||||
connections = {rank: connections[rank].get() for rank in connections}
|
||||
return connections
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include <nanobind/stl/array.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <mscclpp/core.hpp>
|
||||
|
||||
@@ -72,6 +73,7 @@ void register_core(nb::module_& m) {
|
||||
nb::enum_<Transport>(m, "Transport")
|
||||
.value("Unknown", Transport::Unknown)
|
||||
.value("CudaIpc", Transport::CudaIpc)
|
||||
.value("Nvls", Transport::Nvls)
|
||||
.value("IB0", Transport::IB0)
|
||||
.value("IB1", Transport::IB1)
|
||||
.value("IB2", Transport::IB2)
|
||||
@@ -124,6 +126,24 @@ void register_core(nb::module_& m) {
|
||||
.def("transport", &Connection::transport)
|
||||
.def("remote_transport", &Connection::remoteTransport);
|
||||
|
||||
nb::class_<NvlsConnection::DeviceMulticastPointer>(m, "DeviceMulticastPointer")
|
||||
.def("get_device_ptr",
|
||||
[](NvlsConnection::DeviceMulticastPointer* self) { return (uintptr_t)self->getDevicePtr(); })
|
||||
.def("device_handle", &NvlsConnection::DeviceMulticastPointer::deviceHandle);
|
||||
|
||||
nb::class_<NvlsConnection::DeviceMulticastPointer::DeviceHandle>(m, "DeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("devicePtr", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::devicePtr)
|
||||
.def_rw("mcPtr", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::mcPtr)
|
||||
.def_rw("size", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::bufferSize)
|
||||
.def_prop_ro("raw", [](const NvlsConnection::DeviceMulticastPointer::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<NvlsConnection>(m, "NvlsConnection")
|
||||
.def("allocate_bind_memory", &NvlsConnection::allocateAndBindCuda)
|
||||
.def("get_multicast_min_granularity", &NvlsConnection::getMultiCastMinGranularity);
|
||||
|
||||
nb::class_<Endpoint>(m, "Endpoint")
|
||||
.def("transport", &Endpoint::transport)
|
||||
.def("serialize", &Endpoint::serialize)
|
||||
@@ -132,6 +152,7 @@ void register_core(nb::module_& m) {
|
||||
nb::class_<EndpointConfig>(m, "EndpointConfig")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
|
||||
.def(nb::init<Transport, size_t>(), nb::arg("transport"), nb::arg("nvlsBufferSize"))
|
||||
.def_rw("transport", &EndpointConfig::transport)
|
||||
.def_rw("ib_max_cq_size", &EndpointConfig::ibMaxCqSize)
|
||||
.def_rw("ib_max_cq_poll_num", &EndpointConfig::ibMaxCqPollNum)
|
||||
@@ -168,6 +189,7 @@ void register_core(nb::module_& m) {
|
||||
.def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag"))
|
||||
.def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"),
|
||||
nb::arg("localConfig"))
|
||||
.def("connct_nvls_collective", &Communicator::connctNvlsCollective, nb::arg("allRanks"), nb::arg("config"))
|
||||
.def("remote_rank_of", &Communicator::remoteRankOf)
|
||||
.def("tag_of", &Communicator::tagOf)
|
||||
.def("setup", &Communicator::setup);
|
||||
|
||||
@@ -20,4 +20,5 @@ void register_utils(nb::module_& m) {
|
||||
nb::class_<ScopedTimer, Timer>(m, "ScopedTimer").def(nb::init<std::string>(), nb::arg("name"));
|
||||
|
||||
m.def("get_host_name", &getHostName, nb::arg("maxlen"), nb::arg("delim"));
|
||||
m.def("is_nvls_supported", &isNvlsSupported);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user