mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-06-08 15:30:41 +00:00
Change device handle interfaces & others (#142)
* Changed device handle interfaces * Changed proxy service interfaces * Move device code into separate files * Fixed FIFO polling issues * Add configuration arguments in several interface functions --------- Co-authored-by: Changho Hwang <changhohwang@microsoft.com> Co-authored-by: Binyang Li <binyli@microsoft.com> Co-authored-by: root <root@a100-saemal0.qxveptpukjsuthqvv514inp03c.gx.internal.cloudapp.net>
This commit is contained in:
@@ -1,16 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
|
||||
#include <mscclpp/config.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
// Registers the Python-visible "Config" class on module `m`.
// Exposes one static accessor (get_instance) plus a getter/setter pair for
// the bootstrap connection timeout config.
void register_config(nb::module_& m) {
|
||||
nb::class_<Config>(m, "Config")
|
||||
// Bound with rv_policy::reference; presumably so Python does not take
// ownership of the object returned by Config::getInstance -- TODO confirm.
.def_static("get_instance", &Config::getInstance, nb::rv_policy::reference)
|
||||
.def("get_bootstrap_connection_timeout_config", &Config::getBootstrapConnectionTimeoutConfig)
|
||||
.def("set_bootstrap_connection_timeout_config", &Config::setBootstrapConnectionTimeoutConfig);
|
||||
}
|
||||
@@ -17,8 +17,8 @@ extern void register_proxy_channel(nb::module_& m);
|
||||
extern void register_sm_channel(nb::module_& m);
|
||||
extern void register_fifo(nb::module_& m);
|
||||
extern void register_semaphore(nb::module_& m);
|
||||
extern void register_config(nb::module_& m);
|
||||
extern void register_utils(nb::module_& m);
|
||||
extern void register_numa(nb::module_& m);
|
||||
|
||||
template <typename T>
|
||||
void def_nonblocking_future(nb::handle& m, const std::string& typestr) {
|
||||
@@ -62,9 +62,10 @@ void register_core(nb::module_& m) {
|
||||
nb::arg("nRanks"))
|
||||
.def("create_unique_id", &TcpBootstrap::createUniqueId)
|
||||
.def("get_unique_id", &TcpBootstrap::getUniqueId)
|
||||
.def("initialize", (void (TcpBootstrap::*)(UniqueId)) & TcpBootstrap::initialize, nb::arg("uniqueId"))
|
||||
.def("initialize", (void (TcpBootstrap::*)(const std::string&)) & TcpBootstrap::initialize,
|
||||
nb::arg("ifIpPortTrio"));
|
||||
.def("initialize", (void (TcpBootstrap::*)(UniqueId, int64_t)) & TcpBootstrap::initialize, nb::arg("uniqueId"),
|
||||
nb::arg("timeoutSec") = 30)
|
||||
.def("initialize", (void (TcpBootstrap::*)(const std::string&, int64_t)) & TcpBootstrap::initialize,
|
||||
nb::arg("ifIpPortTrio"), nb::arg("timeoutSec") = 30);
|
||||
|
||||
nb::enum_<Transport>(m, "Transport")
|
||||
.value("Unknown", Transport::Unknown)
|
||||
@@ -118,7 +119,7 @@ void register_core(nb::module_& m) {
|
||||
self->updateAndSync(dst, dstOffset, (uint64_t*)src, newValue);
|
||||
},
|
||||
nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("newValue"))
|
||||
.def("flush", &Connection::flush)
|
||||
.def("flush", &Connection::flush, nb::arg("timeoutUsec") = (int64_t)3e7)
|
||||
.def("remote_rank", &Connection::remoteRank)
|
||||
.def("tag", &Connection::tag)
|
||||
.def("transport", &Connection::transport)
|
||||
@@ -139,7 +140,8 @@ void register_core(nb::module_& m) {
|
||||
nb::arg("tag"))
|
||||
.def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag"))
|
||||
.def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"),
|
||||
nb::arg("transport"))
|
||||
nb::arg("transport"), nb::arg("ibMaxCqSize") = 1024, nb::arg("ibMaxCqPollNum") = 1,
|
||||
nb::arg("ibMaxSendWr") = 8192, nb::arg("ibMaxWrPerSend") = 64)
|
||||
.def("setup", &Communicator::setup);
|
||||
}
|
||||
|
||||
@@ -149,7 +151,7 @@ NB_MODULE(_mscclpp, m) {
|
||||
register_sm_channel(m);
|
||||
register_fifo(m);
|
||||
register_semaphore(m);
|
||||
register_config(m);
|
||||
register_utils(m);
|
||||
register_core(m);
|
||||
register_numa(m);
|
||||
}
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import mscclpp
|
||||
import argparse
|
||||
import multiprocessing as mp
|
||||
import logging
|
||||
import torch
|
||||
import multiprocessing as mp
|
||||
import sys
|
||||
|
||||
import mscclpp
|
||||
import torch
|
||||
|
||||
IB_TRANSPORTS = [
|
||||
mscclpp.Transport.IB0,
|
||||
mscclpp.Transport.IB1,
|
||||
@@ -19,15 +20,19 @@ IB_TRANSPORTS = [
|
||||
mscclpp.Transport.IB7,
|
||||
]
|
||||
|
||||
# Used to hold the sm channels so they don't get garbage collected
|
||||
sm_channels = []
|
||||
|
||||
|
||||
def setup_connections(comm, rank, world_size, element_size, proxy_service):
|
||||
simple_proxy_channels = []
|
||||
sm_semaphores = []
|
||||
connections = []
|
||||
remote_memories = []
|
||||
memory = torch.zeros(element_size, dtype=torch.int32)
|
||||
memory = memory.to("cuda")
|
||||
|
||||
transport_flag = IB_TRANSPORTS[rank] or mscclpp.Transport.CudaIpc
|
||||
transport_flag = mscclpp.TransportFlags(IB_TRANSPORTS[rank]) | mscclpp.Transport.CudaIpc
|
||||
ptr = memory.data_ptr()
|
||||
size = memory.numel() * memory.element_size()
|
||||
reg_mem = comm.register_memory(ptr, size, transport_flag)
|
||||
@@ -42,15 +47,26 @@ def setup_connections(comm, rank, world_size, element_size, proxy_service):
|
||||
remote_memories.append(remote_mem)
|
||||
comm.setup()
|
||||
|
||||
# Create simple proxy channels
|
||||
for i, conn in enumerate(connections):
|
||||
proxy_channel = mscclpp.SimpleProxyChannel(
|
||||
proxy_service.device_channel(proxy_service.add_semaphore(conn)),
|
||||
proxy_service.proxy_channel(proxy_service.build_and_add_semaphore(conn)),
|
||||
proxy_service.add_memory(remote_memories[i].get()),
|
||||
proxy_service.add_memory(reg_mem),
|
||||
)
|
||||
simple_proxy_channels.append(mscclpp.device_handle(proxy_channel))
|
||||
comm.setup()
|
||||
return simple_proxy_channels
|
||||
|
||||
# Create sm channels
|
||||
for i, conn in enumerate(connections):
|
||||
sm_chan = mscclpp.SmDevice2DeviceSemaphore(comm, conn)
|
||||
sm_semaphores.append(sm_chan)
|
||||
comm.setup()
|
||||
|
||||
for i, conn in enumerate(sm_semaphores):
|
||||
sm_chan = mscclpp.SmChannel(sm_semaphores[i], remote_memories[i].get(), ptr)
|
||||
sm_channels.append(sm_chan)
|
||||
return simple_proxy_channels, [mscclpp.device_handle(sm_chan) for sm_chan in sm_channels]
|
||||
|
||||
|
||||
def run(rank, args):
|
||||
@@ -60,7 +76,7 @@ def run(rank, args):
|
||||
boot = mscclpp.TcpBootstrap.create(rank, world_size)
|
||||
boot.initialize(args.if_ip_port_trio)
|
||||
comm = mscclpp.Communicator(boot)
|
||||
proxy_service = mscclpp.ProxyService(comm)
|
||||
proxy_service = mscclpp.ProxyService()
|
||||
|
||||
logging.info("Rank: %d, setting up connections", rank)
|
||||
setup_connections(comm, rank, world_size, args.num_elements, proxy_service)
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
import mscclpp
|
||||
|
||||
|
||||
def main():
|
||||
# Round-trip check on the Config binding: write 15 through the setter,
# read it back through the getter, and assert the two values match.
config = mscclpp.Config.get_instance()
|
||||
config.set_bootstrap_connection_timeout_config(15)
|
||||
timeout = config.get_bootstrap_connection_timeout_config()
|
||||
assert timeout == 15
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,10 +1,11 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import mscclpp
|
||||
import argparse
|
||||
import time
|
||||
|
||||
import mscclpp
|
||||
|
||||
|
||||
def main(args):
|
||||
if args.root:
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import mscclpp
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import time
|
||||
|
||||
import mscclpp
|
||||
|
||||
|
||||
def main():
|
||||
timer = mscclpp.Timer()
|
||||
|
||||
@@ -11,15 +11,20 @@ using namespace mscclpp;
|
||||
void register_fifo(nb::module_& m) {
|
||||
nb::class_<ProxyTrigger>(m, "ProxyTrigger").def_rw("fst", &ProxyTrigger::fst).def_rw("snd", &ProxyTrigger::snd);
|
||||
|
||||
nb::class_<DeviceProxyFifo>(m, "DeviceProxyFifo")
|
||||
.def_rw("triggers", &DeviceProxyFifo::triggers)
|
||||
.def_rw("tail_replica", &DeviceProxyFifo::tailReplica)
|
||||
.def_rw("head", &DeviceProxyFifo::head);
|
||||
nb::class_<FifoDeviceHandle>(m, "FifoDeviceHandle")
|
||||
.def_rw("triggers", &FifoDeviceHandle::triggers)
|
||||
.def_rw("tail_replica", &FifoDeviceHandle::tailReplica)
|
||||
.def_rw("head", &FifoDeviceHandle::head)
|
||||
.def_rw("size", &FifoDeviceHandle::size)
|
||||
.def_prop_ro("raw", [](const FifoDeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<HostProxyFifo>(m, "HostProxyFifo")
|
||||
.def(nb::init<>())
|
||||
.def("poll", &HostProxyFifo::poll, nb::arg("trigger"))
|
||||
.def("pop", &HostProxyFifo::pop)
|
||||
.def("flush_tail", &HostProxyFifo::flushTail, nb::arg("sync") = false)
|
||||
.def("device_fifo", &HostProxyFifo::deviceFifo);
|
||||
nb::class_<Fifo>(m, "Fifo")
|
||||
.def(nb::init<int>(), nb::arg("size") = 128)
|
||||
.def("poll", &Fifo::poll)
|
||||
.def("pop", &Fifo::pop)
|
||||
.def("flush_tail", &Fifo::flushTail, nb::arg("sync") = false)
|
||||
.def("size", &Fifo::size)
|
||||
.def("device_handle", &Fifo::deviceHandle);
|
||||
}
|
||||
|
||||
@@ -1,10 +1,31 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from ._mscclpp import *
|
||||
import os as _os
|
||||
|
||||
from ._mscclpp import (
|
||||
Communicator,
|
||||
Connection,
|
||||
Fifo,
|
||||
Host2DeviceSemaphore,
|
||||
Host2HostSemaphore,
|
||||
numa,
|
||||
ProxyService,
|
||||
RegisteredMemory,
|
||||
SimpleProxyChannel,
|
||||
SmChannel,
|
||||
SmDevice2DeviceSemaphore,
|
||||
TcpBootstrap,
|
||||
Transport,
|
||||
TransportFlags,
|
||||
)
|
||||
|
||||
|
||||
def get_include():
|
||||
"""Return the directory that contains the MSCCL++ headers."""
|
||||
# Path is the "include" entry joined onto the directory of this module's own file.
return _os.path.join(_os.path.dirname(__file__), "include")
|
||||
|
||||
|
||||
def get_lib():
|
||||
"""Return the directory that contains the MSCCL++ libraries."""
|
||||
# Path is the "lib" entry joined onto the directory of this module's own file.
return _os.path.join(_os.path.dirname(__file__), "lib")
|
||||
|
||||
13
python/numa_py.cpp
Normal file
13
python/numa_py.cpp
Normal file
@@ -0,0 +1,13 @@
|
||||
#include <nanobind/nanobind.h>
|
||||
namespace nb = nanobind;
|
||||
|
||||
namespace mscclpp {
|
||||
int getDeviceNumaNode(int cudaDev);
|
||||
void numaBind(int node);
|
||||
}; // namespace mscclpp
|
||||
|
||||
// Registers a "numa" submodule on module `m` and binds the two free functions
// forward-declared in namespace mscclpp earlier in this file.
void register_numa(nb::module_ &m) {
|
||||
// Submodule name is "numa"; the second argument is its docstring.
nb::module_ sub_m = m.def_submodule("numa", "numa functions");
|
||||
// Python name get_device_numa_node -> mscclpp::getDeviceNumaNode(int cudaDev).
sub_m.def("get_device_numa_node", &mscclpp::getDeviceNumaNode);
|
||||
// Python name numa_bind -> mscclpp::numaBind(int node).
sub_m.def("numa_bind", &mscclpp::numaBind);
|
||||
}
|
||||
@@ -16,22 +16,40 @@ void register_proxy_channel(nb::module_& m) {
|
||||
.def("stop_proxy", &BaseProxyService::stopProxy);
|
||||
|
||||
nb::class_<ProxyService, BaseProxyService>(m, "ProxyService")
|
||||
.def(nb::init<Communicator&>(), nb::arg("comm"))
|
||||
.def(nb::init<>())
|
||||
.def("start_proxy", &ProxyService::startProxy)
|
||||
.def("stop_proxy", &ProxyService::stopProxy)
|
||||
.def("add_semaphore", &ProxyService::addSemaphore, nb::arg("connection"))
|
||||
.def("build_and_add_semaphore", &ProxyService::buildAndAddSemaphore, nb::arg("comm"), nb::arg("connection"))
|
||||
.def("add_semaphore", &ProxyService::addSemaphore, nb::arg("semaphore"))
|
||||
.def("add_memory", &ProxyService::addMemory, nb::arg("memory"))
|
||||
.def("semaphore", &ProxyService::semaphore, nb::arg("id"))
|
||||
.def("device_channel", &ProxyService::deviceChannel, nb::arg("id"));
|
||||
.def("proxy_channel", &ProxyService::proxyChannel, nb::arg("id"));
|
||||
|
||||
nb::class_<ProxyChannel>(m, "ProxyChannel")
|
||||
.def(nb::init<SemaphoreId, Host2DeviceSemaphore::DeviceHandle, DeviceProxyFifo>(), nb::arg("semaphoreId"),
|
||||
nb::arg("semaphore"), nb::arg("fifo"));
|
||||
.def(nb::init<SemaphoreId, Host2DeviceSemaphore::DeviceHandle, FifoDeviceHandle>(), nb::arg("semaphoreId"),
|
||||
nb::arg("semaphore"), nb::arg("fifo"))
|
||||
.def("device_handle", &ProxyChannel::deviceHandle);
|
||||
|
||||
nb::class_<ProxyChannel::DeviceHandle>(m, "ProxyChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("semaphoreId_", &ProxyChannel::DeviceHandle::semaphoreId_)
|
||||
.def_rw("semaphore_", &ProxyChannel::DeviceHandle::semaphore_)
|
||||
.def_rw("fifo_", &ProxyChannel::DeviceHandle::fifo_)
|
||||
.def_prop_ro("raw", [](const ProxyChannel::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<SimpleProxyChannel>(m, "SimpleProxyChannel")
|
||||
.def(nb::init<ProxyChannel, MemoryId, MemoryId>(), nb::arg("proxyChan"), nb::arg("dst"), nb::arg("src"))
|
||||
.def(nb::init<SimpleProxyChannel>(), nb::arg("proxyChan"));
|
||||
.def(nb::init<SimpleProxyChannel>(), nb::arg("proxyChan"))
|
||||
.def("device_handle", &SimpleProxyChannel::deviceHandle);
|
||||
|
||||
m.def("device_handle", &deviceHandle<ProxyChannel>, nb::arg("proxyChannel"));
|
||||
m.def("device_handle", &deviceHandle<SimpleProxyChannel>, nb::arg("simpleProxyChannel"));
|
||||
nb::class_<SimpleProxyChannel::DeviceHandle>(m, "SimpleProxyChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("proxyChan_", &SimpleProxyChannel::DeviceHandle::proxyChan_)
|
||||
.def_rw("src_", &SimpleProxyChannel::DeviceHandle::src_)
|
||||
.def_rw("dst_", &SimpleProxyChannel::DeviceHandle::dst_)
|
||||
.def_prop_ro("raw", [](const SimpleProxyChannel::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
};
|
||||
|
||||
@@ -20,7 +20,10 @@ void register_semaphore(nb::module_& m) {
|
||||
nb::class_<Host2DeviceSemaphore::DeviceHandle>(host2DeviceSemaphore, "DeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("inbound_semaphore_id", &Host2DeviceSemaphore::DeviceHandle::inboundSemaphoreId)
|
||||
.def_rw("expected_inbound_semaphore_id", &Host2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId);
|
||||
.def_rw("expected_inbound_semaphore_id", &Host2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId)
|
||||
.def_prop_ro("raw", [](const Host2DeviceSemaphore::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<Host2HostSemaphore>(m, "Host2HostSemaphore")
|
||||
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
@@ -38,5 +41,8 @@ void register_semaphore(nb::module_& m) {
|
||||
.def_rw("inboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::inboundSemaphoreId)
|
||||
.def_rw("outboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::outboundSemaphoreId)
|
||||
.def_rw("remoteInboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::remoteInboundSemaphoreId)
|
||||
.def_rw("expectedInboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId);
|
||||
.def_rw("expectedInboundSemaphoreId", &SmDevice2DeviceSemaphore::DeviceHandle::expectedInboundSemaphoreId)
|
||||
.def_prop_ro("raw", [](const SmDevice2DeviceSemaphore::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
}
|
||||
|
||||
@@ -13,11 +13,23 @@ using namespace mscclpp;
|
||||
void register_sm_channel(nb::module_& m) {
|
||||
nb::class_<SmChannel> smChannel(m, "SmChannel");
|
||||
smChannel
|
||||
.def(nb::init<std::shared_ptr<SmDevice2DeviceSemaphore>, RegisteredMemory, void*, void*>(), nb::arg("semaphore"),
|
||||
nb::arg("dst"), nb::arg("src"), nb::arg("getPacketBuffer"))
|
||||
.def("__init__",
|
||||
[](SmChannel* smChannel, std::shared_ptr<SmDevice2DeviceSemaphore> semaphore, RegisteredMemory dst,
|
||||
uintptr_t src) { new (smChannel) SmChannel(semaphore, dst, (void*)src); })
|
||||
.def("__init__",
|
||||
[](SmChannel* smChannel, std::shared_ptr<SmDevice2DeviceSemaphore> semaphore, RegisteredMemory dst,
|
||||
uintptr_t src, uintptr_t get_packet_buffer) {
|
||||
new (smChannel) SmChannel(semaphore, dst, (void*)src, (void*)get_packet_buffer);
|
||||
})
|
||||
.def("device_handle", &SmChannel::deviceHandle);
|
||||
|
||||
nb::class_<SmChannel::DeviceHandle>(smChannel, "DeviceHandle");
|
||||
|
||||
m.def("device_handle", &deviceHandle<SmChannel>, nb::arg("smChannel"));
|
||||
nb::class_<SmChannel::DeviceHandle>(m, "SmChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("semaphore_", &SmChannel::DeviceHandle::semaphore_)
|
||||
.def_rw("src_", &SmChannel::DeviceHandle::src_)
|
||||
.def_rw("dst_", &SmChannel::DeviceHandle::dst_)
|
||||
.def_rw("getPacketBuffer_", &SmChannel::DeviceHandle::getPacketBuffer_)
|
||||
.def_prop_ro("raw", [](const SmChannel::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user